jinns-1.3.0-py3-none-any.whl → jinns-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +17 -7
- jinns/data/_AbstractDataGenerator.py +19 -0
- jinns/data/_Batchs.py +31 -12
- jinns/data/_CubicMeshPDENonStatio.py +431 -0
- jinns/data/_CubicMeshPDEStatio.py +464 -0
- jinns/data/_DataGeneratorODE.py +187 -0
- jinns/data/_DataGeneratorObservations.py +189 -0
- jinns/data/_DataGeneratorParameter.py +206 -0
- jinns/data/__init__.py +19 -9
- jinns/data/_utils.py +149 -0
- jinns/experimental/__init__.py +9 -0
- jinns/loss/_DynamicLoss.py +114 -187
- jinns/loss/_DynamicLossAbstract.py +45 -68
- jinns/loss/_LossODE.py +71 -336
- jinns/loss/_LossPDE.py +146 -520
- jinns/loss/__init__.py +28 -6
- jinns/loss/_abstract_loss.py +15 -0
- jinns/loss/_boundary_conditions.py +20 -19
- jinns/loss/_loss_utils.py +78 -159
- jinns/loss/_loss_weights.py +12 -44
- jinns/loss/_operators.py +84 -74
- jinns/nn/__init__.py +15 -0
- jinns/nn/_abstract_pinn.py +22 -0
- jinns/nn/_hyperpinn.py +94 -57
- jinns/nn/_mlp.py +50 -25
- jinns/nn/_pinn.py +33 -19
- jinns/nn/_ppinn.py +70 -34
- jinns/nn/_save_load.py +21 -51
- jinns/nn/_spinn.py +33 -16
- jinns/nn/_spinn_mlp.py +28 -22
- jinns/nn/_utils.py +38 -0
- jinns/parameters/__init__.py +8 -1
- jinns/parameters/_derivative_keys.py +116 -177
- jinns/parameters/_params.py +18 -46
- jinns/plot/__init__.py +2 -0
- jinns/plot/_plot.py +35 -34
- jinns/solver/_rar.py +80 -63
- jinns/solver/_solve.py +89 -63
- jinns/solver/_utils.py +4 -6
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +12 -9
- jinns/utils/_types.py +11 -57
- jinns/utils/_utils.py +4 -11
- jinns/validation/__init__.py +2 -0
- jinns/validation/_validation.py +20 -19
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/METADATA +4 -3
- jinns-1.4.0.dist-info/RECORD +53 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/WHEEL +1 -1
- jinns/data/_DataGenerators.py +0 -1634
- jinns-1.3.0.dist-info/RECORD +0 -44
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/AUTHORS +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info/licenses}/LICENSE +0 -0
- {jinns-1.3.0.dist-info → jinns-1.4.0.dist-info}/top_level.txt +0 -0
jinns/data/_DataGenerators.py
DELETED
|
@@ -1,1634 +0,0 @@
|
|
|
1
|
-
# pylint: disable=unsubscriptable-object
|
|
2
|
-
"""
|
|
3
|
-
Define the DataGenerators modules
|
|
4
|
-
"""
|
|
5
|
-
from __future__ import (
|
|
6
|
-
annotations,
|
|
7
|
-
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
-
import warnings
|
|
9
|
-
from typing import TYPE_CHECKING, Dict
|
|
10
|
-
from dataclasses import InitVar
|
|
11
|
-
import equinox as eqx
|
|
12
|
-
import jax
|
|
13
|
-
import jax.numpy as jnp
|
|
14
|
-
from jaxtyping import Key, Int, PyTree, Array, Float, Bool
|
|
15
|
-
from jinns.data._Batchs import *
|
|
16
|
-
|
|
17
|
-
if TYPE_CHECKING:
|
|
18
|
-
from jinns.utils._types import *
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
|
|
22
|
-
"""
|
|
23
|
-
Utility function that fills the field `batch.param_batch_dict` of a batch object.
|
|
24
|
-
"""
|
|
25
|
-
return eqx.tree_at(
|
|
26
|
-
lambda m: m.param_batch_dict,
|
|
27
|
-
batch,
|
|
28
|
-
param_batch_dict,
|
|
29
|
-
is_leaf=lambda x: x is None,
|
|
30
|
-
)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
|
|
34
|
-
"""
|
|
35
|
-
Utility function that fills the field `batch.obs_batch_dict` of a batch object
|
|
36
|
-
"""
|
|
37
|
-
return eqx.tree_at(
|
|
38
|
-
lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
|
|
39
|
-
)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def make_cartesian_product(
|
|
43
|
-
b1: Float[Array, "batch_size dim1"], b2: Float[Array, "batch_size dim2"]
|
|
44
|
-
) -> Float[Array, "(batch_size*batch_size) (dim1+dim2)"]:
|
|
45
|
-
"""
|
|
46
|
-
Create the cartesian product of a time and a border omega batches
|
|
47
|
-
by tiling and repeating
|
|
48
|
-
"""
|
|
49
|
-
n1 = b1.shape[0]
|
|
50
|
-
n2 = b2.shape[0]
|
|
51
|
-
b1 = jnp.repeat(b1, n2, axis=0)
|
|
52
|
-
b2 = jnp.tile(b2, reps=(n1,) + tuple(1 for i in b2.shape[1:]))
|
|
53
|
-
return jnp.concatenate([b1, b2], axis=1)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def _reset_batch_idx_and_permute(
|
|
57
|
-
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
58
|
-
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
59
|
-
key, domain, curr_idx, _, p = operands
|
|
60
|
-
# resetting counter
|
|
61
|
-
curr_idx = 0
|
|
62
|
-
# reshuffling
|
|
63
|
-
key, subkey = jax.random.split(key)
|
|
64
|
-
if p is None:
|
|
65
|
-
domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
|
|
66
|
-
else:
|
|
67
|
-
# otherwise p is used to avoid collocation points not in n_start
|
|
68
|
-
# NOTE that replace=True to avoid undefined behaviour but then, the
|
|
69
|
-
# domain.shape[0] does not really grow as in the original RAR. instead,
|
|
70
|
-
# it always comprises the same number of points, but the points are
|
|
71
|
-
# updated
|
|
72
|
-
domain = jax.random.choice(
|
|
73
|
-
subkey, domain, shape=(domain.shape[0],), replace=True, p=p
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
# return updated
|
|
77
|
-
return (key, domain, curr_idx)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def _increment_batch_idx(
|
|
81
|
-
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
82
|
-
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
83
|
-
key, domain, curr_idx, batch_size, _ = operands
|
|
84
|
-
# simply increases counter and get the batch
|
|
85
|
-
curr_idx += batch_size
|
|
86
|
-
return (key, domain, curr_idx)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def _reset_or_increment(
|
|
90
|
-
bend: Int,
|
|
91
|
-
n_eff: Int,
|
|
92
|
-
operands: tuple[Key, Float[Array, "n dimension"], Int, None, Float[Array, "n"]],
|
|
93
|
-
) -> tuple[Key, Float[Array, "n dimension"], Int]:
|
|
94
|
-
"""
|
|
95
|
-
Factorize the code of the jax.lax.cond which checks if we have seen all the
|
|
96
|
-
batches in an epoch
|
|
97
|
-
If bend > n_eff (ie n when no RAR sampling) we reshuffle and start from 0
|
|
98
|
-
again. Otherwise, if bend < n_eff, this means there are still *_batch_size
|
|
99
|
-
samples at least that have not been seen and we can take a new batch
|
|
100
|
-
|
|
101
|
-
Parameters
|
|
102
|
-
----------
|
|
103
|
-
bend
|
|
104
|
-
An integer. The new hypothetical index for the starting of the batch
|
|
105
|
-
n_eff
|
|
106
|
-
An integer. The number of points to see to complete an epoch
|
|
107
|
-
operands
|
|
108
|
-
A tuple. As passed to _reset_batch_idx_and_permute and
|
|
109
|
-
_increment_batch_idx
|
|
110
|
-
|
|
111
|
-
Returns
|
|
112
|
-
-------
|
|
113
|
-
res
|
|
114
|
-
A tuple as returned by _reset_batch_idx_and_permute or
|
|
115
|
-
_increment_batch_idx
|
|
116
|
-
"""
|
|
117
|
-
return jax.lax.cond(
|
|
118
|
-
bend > n_eff, _reset_batch_idx_and_permute, _increment_batch_idx, operands
|
|
119
|
-
)
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def _check_and_set_rar_parameters(
|
|
123
|
-
rar_parameters: dict, n: Int, n_start: Int
|
|
124
|
-
) -> tuple[Int, Float[Array, "n"], Int, Int]:
|
|
125
|
-
if rar_parameters is not None and n_start is None:
|
|
126
|
-
raise ValueError(
|
|
127
|
-
"n_start must be provided in the context of RAR sampling scheme"
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
if rar_parameters is not None:
|
|
131
|
-
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
132
|
-
# probability to specify non-used collocation points (i.e. points
|
|
133
|
-
# above n_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
134
|
-
p = jnp.zeros((n,))
|
|
135
|
-
p = p.at[:n_start].set(1 / n_start)
|
|
136
|
-
# set internal counter for the number of gradient steps since the
|
|
137
|
-
# last new collocation points have been added
|
|
138
|
-
# It is not 0 to ensure the first iteration of RAR happens just
|
|
139
|
-
# after start_iter. See the _proceed_to_rar() function in _rar.py
|
|
140
|
-
rar_iter_from_last_sampling = rar_parameters["update_every"] - 1
|
|
141
|
-
# set iternal counter for the number of times collocation points
|
|
142
|
-
# have been added
|
|
143
|
-
rar_iter_nb = 0
|
|
144
|
-
else:
|
|
145
|
-
n_start = n
|
|
146
|
-
p = None
|
|
147
|
-
rar_iter_from_last_sampling = None
|
|
148
|
-
rar_iter_nb = None
|
|
149
|
-
|
|
150
|
-
return n_start, p, rar_iter_from_last_sampling, rar_iter_nb
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
class DataGeneratorODE(eqx.Module):
|
|
154
|
-
"""
|
|
155
|
-
A class implementing data generator object for ordinary differential equations.
|
|
156
|
-
|
|
157
|
-
Parameters
|
|
158
|
-
----------
|
|
159
|
-
key : Key
|
|
160
|
-
Jax random key to sample new time points and to shuffle batches
|
|
161
|
-
nt : Int
|
|
162
|
-
The number of total time points that will be divided in
|
|
163
|
-
batches. Batches are made so that each data point is seen only
|
|
164
|
-
once during 1 epoch.
|
|
165
|
-
tmin : float
|
|
166
|
-
The minimum value of the time domain to consider
|
|
167
|
-
tmax : float
|
|
168
|
-
The maximum value of the time domain to consider
|
|
169
|
-
temporal_batch_size : int | None, default=None
|
|
170
|
-
The size of the batch of randomly selected points among
|
|
171
|
-
the `nt` points. If None, no minibatches are used.
|
|
172
|
-
method : str, default="uniform"
|
|
173
|
-
Either `grid` or `uniform`, default is `uniform`.
|
|
174
|
-
The method that generates the `nt` time points. `grid` means
|
|
175
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
176
|
-
sampled points over the domain
|
|
177
|
-
rar_parameters : Dict[str, Int], default=None
|
|
178
|
-
Defaults to None: do not use Residual Adaptative Resampling.
|
|
179
|
-
Otherwise a dictionary with keys
|
|
180
|
-
|
|
181
|
-
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
182
|
-
- `update_every`: the number of gradient steps taken between
|
|
183
|
-
each update of collocation points in the RAR algo.
|
|
184
|
-
- `sample_size`: the size of the sample from which we will select new
|
|
185
|
-
collocation points.
|
|
186
|
-
- `selected_sample_size`: the number of selected
|
|
187
|
-
points from the sample to be added to the current collocation
|
|
188
|
-
points.
|
|
189
|
-
n_start : Int, default=None
|
|
190
|
-
Defaults to None. The effective size of nt used at start time.
|
|
191
|
-
This value must be
|
|
192
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
193
|
-
n_start = nt and this is hidden from the user.
|
|
194
|
-
In RAR, n_start
|
|
195
|
-
then corresponds to the initial number of points we train the PINN.
|
|
196
|
-
"""
|
|
197
|
-
|
|
198
|
-
key: Key = eqx.field(kw_only=True)
|
|
199
|
-
nt: Int = eqx.field(kw_only=True, static=True)
|
|
200
|
-
tmin: Float = eqx.field(kw_only=True)
|
|
201
|
-
tmax: Float = eqx.field(kw_only=True)
|
|
202
|
-
temporal_batch_size: Int | None = eqx.field(static=True, default=None, kw_only=True)
|
|
203
|
-
method: str = eqx.field(
|
|
204
|
-
static=True, kw_only=True, default_factory=lambda: "uniform"
|
|
205
|
-
)
|
|
206
|
-
rar_parameters: Dict[str, Int] = eqx.field(default=None, kw_only=True)
|
|
207
|
-
n_start: Int = eqx.field(static=True, default=None, kw_only=True)
|
|
208
|
-
|
|
209
|
-
# all the init=False fields are set in __post_init__
|
|
210
|
-
p: Float[Array, "nt 1"] = eqx.field(init=False)
|
|
211
|
-
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
212
|
-
rar_iter_nb: Int = eqx.field(init=False)
|
|
213
|
-
curr_time_idx: Int = eqx.field(init=False)
|
|
214
|
-
times: Float[Array, "nt 1"] = eqx.field(init=False)
|
|
215
|
-
|
|
216
|
-
def __post_init__(self):
|
|
217
|
-
(
|
|
218
|
-
self.n_start,
|
|
219
|
-
self.p,
|
|
220
|
-
self.rar_iter_from_last_sampling,
|
|
221
|
-
self.rar_iter_nb,
|
|
222
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)
|
|
223
|
-
|
|
224
|
-
if self.temporal_batch_size is not None:
|
|
225
|
-
self.curr_time_idx = self.nt + self.temporal_batch_size
|
|
226
|
-
# to be sure there is a shuffling at first get_batch()
|
|
227
|
-
# NOTE in the extreme case we could do:
|
|
228
|
-
# self.curr_time_idx=jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
|
|
229
|
-
# but we do not test for such extreme values. Where we subtract
|
|
230
|
-
# self.temporal_batch_size - 1 because otherwise when computing
|
|
231
|
-
# `bend` we do not want to overflow the max int32 with unwanted behaviour
|
|
232
|
-
else:
|
|
233
|
-
self.curr_time_idx = 0
|
|
234
|
-
|
|
235
|
-
self.key, self.times = self.generate_time_data(self.key)
|
|
236
|
-
# Note that, here, in __init__ (and __post_init__), this is the
|
|
237
|
-
# only place where self assignment are authorized so we do the
|
|
238
|
-
# above way for the key.
|
|
239
|
-
|
|
240
|
-
def sample_in_time_domain(
|
|
241
|
-
self, key: Key, sample_size: Int = None
|
|
242
|
-
) -> Float[Array, "nt 1"]:
|
|
243
|
-
return jax.random.uniform(
|
|
244
|
-
key,
|
|
245
|
-
(self.nt if sample_size is None else sample_size, 1),
|
|
246
|
-
minval=self.tmin,
|
|
247
|
-
maxval=self.tmax,
|
|
248
|
-
)
|
|
249
|
-
|
|
250
|
-
def generate_time_data(self, key: Key) -> tuple[Key, Float[Array, "nt"]]:
|
|
251
|
-
"""
|
|
252
|
-
Construct a complete set of `self.nt` time points according to the
|
|
253
|
-
specified `self.method`
|
|
254
|
-
|
|
255
|
-
Note that self.times has always size self.nt and not self.n_start, even
|
|
256
|
-
in RAR scheme, we must allocate all the collocation points
|
|
257
|
-
"""
|
|
258
|
-
key, subkey = jax.random.split(self.key)
|
|
259
|
-
if self.method == "grid":
|
|
260
|
-
partial_times = (self.tmax - self.tmin) / self.nt
|
|
261
|
-
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
262
|
-
if self.method == "uniform":
|
|
263
|
-
return key, self.sample_in_time_domain(subkey)
|
|
264
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
265
|
-
|
|
266
|
-
def _get_time_operands(
|
|
267
|
-
self,
|
|
268
|
-
) -> tuple[Key, Float[Array, "nt 1"], Int, Int, Float[Array, "nt 1"]]:
|
|
269
|
-
return (
|
|
270
|
-
self.key,
|
|
271
|
-
self.times,
|
|
272
|
-
self.curr_time_idx,
|
|
273
|
-
self.temporal_batch_size,
|
|
274
|
-
self.p,
|
|
275
|
-
)
|
|
276
|
-
|
|
277
|
-
def temporal_batch(
|
|
278
|
-
self,
|
|
279
|
-
) -> tuple["DataGeneratorODE", Float[Array, "temporal_batch_size"]]:
|
|
280
|
-
"""
|
|
281
|
-
Return a batch of time points. If all the batches have been seen, we
|
|
282
|
-
reshuffle them, otherwise we just return the next unseen batch.
|
|
283
|
-
"""
|
|
284
|
-
if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
|
|
285
|
-
# Avoid unnecessary reshuffling
|
|
286
|
-
return self, self.times
|
|
287
|
-
|
|
288
|
-
bstart = self.curr_time_idx
|
|
289
|
-
bend = bstart + self.temporal_batch_size
|
|
290
|
-
|
|
291
|
-
# Compute the effective number of used collocation points
|
|
292
|
-
if self.rar_parameters is not None:
|
|
293
|
-
nt_eff = (
|
|
294
|
-
self.n_start
|
|
295
|
-
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
296
|
-
)
|
|
297
|
-
else:
|
|
298
|
-
nt_eff = self.nt
|
|
299
|
-
|
|
300
|
-
new_attributes = _reset_or_increment(bend, nt_eff, self._get_time_operands())
|
|
301
|
-
new = eqx.tree_at(
|
|
302
|
-
lambda m: (m.key, m.times, m.curr_time_idx), self, new_attributes
|
|
303
|
-
)
|
|
304
|
-
|
|
305
|
-
# commands below are equivalent to
|
|
306
|
-
# return self.times[i:(i+t_batch_size)]
|
|
307
|
-
# start indices can be dynamic but the slice shape is fixed
|
|
308
|
-
return new, jax.lax.dynamic_slice(
|
|
309
|
-
new.times,
|
|
310
|
-
start_indices=(new.curr_time_idx, 0),
|
|
311
|
-
slice_sizes=(new.temporal_batch_size, 1),
|
|
312
|
-
)
|
|
313
|
-
|
|
314
|
-
def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
|
|
315
|
-
"""
|
|
316
|
-
Generic method to return a batch. Here we call `self.temporal_batch()`
|
|
317
|
-
"""
|
|
318
|
-
new, temporal_batch = self.temporal_batch()
|
|
319
|
-
return new, ODEBatch(temporal_batch=temporal_batch)
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
class CubicMeshPDEStatio(eqx.Module):
|
|
323
|
-
r"""
|
|
324
|
-
A class implementing data generator object for stationary partial
|
|
325
|
-
differential equations.
|
|
326
|
-
|
|
327
|
-
Parameters
|
|
328
|
-
----------
|
|
329
|
-
key : Key
|
|
330
|
-
Jax random key to sample new time points and to shuffle batches
|
|
331
|
-
n : Int
|
|
332
|
-
The number of total $\Omega$ points that will be divided in
|
|
333
|
-
batches. Batches are made so that each data point is seen only
|
|
334
|
-
once during 1 epoch.
|
|
335
|
-
nb : Int | None
|
|
336
|
-
The total number of points in $\partial\Omega$. Can be None if no
|
|
337
|
-
boundary condition is specified.
|
|
338
|
-
omega_batch_size : Int | None, default=None
|
|
339
|
-
The size of the batch of randomly selected points among
|
|
340
|
-
the `n` points. If None no minibatches are used.
|
|
341
|
-
omega_border_batch_size : Int | None, default=None
|
|
342
|
-
The size of the batch of points randomly selected
|
|
343
|
-
among the `nb` points. If None, `omega_border_batch_size`
|
|
344
|
-
no minibatches are used. In dimension 1,
|
|
345
|
-
minibatches are never used since the boundary is composed of two
|
|
346
|
-
singletons.
|
|
347
|
-
dim : Int
|
|
348
|
-
Dimension of $\Omega$ domain
|
|
349
|
-
min_pts : tuple[tuple[Float, Float], ...]
|
|
350
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
351
|
-
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
352
|
-
x_{n, min})$
|
|
353
|
-
max_pts : tuple[tuple[Float, Float], ...]
|
|
354
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
355
|
-
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
356
|
-
x_{n,max})$
|
|
357
|
-
method : str, default="uniform"
|
|
358
|
-
Either `grid` or `uniform`, default is `uniform`.
|
|
359
|
-
The method that generates the `nt` time points. `grid` means
|
|
360
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
361
|
-
sampled points over the domain
|
|
362
|
-
rar_parameters : Dict[str, Int], default=None
|
|
363
|
-
Defaults to None: do not use Residual Adaptative Resampling.
|
|
364
|
-
Otherwise a dictionary with keys
|
|
365
|
-
|
|
366
|
-
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
367
|
-
- `update_every`: the number of gradient steps taken between
|
|
368
|
-
each update of collocation points in the RAR algo.
|
|
369
|
-
- `sample_size`: the size of the sample from which we will select new
|
|
370
|
-
collocation points.
|
|
371
|
-
- `selected_sample_size`: the number of selected
|
|
372
|
-
points from the sample to be added to the current collocation
|
|
373
|
-
points.
|
|
374
|
-
n_start : Int, default=None
|
|
375
|
-
Defaults to None. The effective size of n used at start time.
|
|
376
|
-
This value must be
|
|
377
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
378
|
-
n_start = n and this is hidden from the user.
|
|
379
|
-
In RAR, n_start
|
|
380
|
-
then corresponds to the initial number of points we train the PINN on.
|
|
381
|
-
"""
|
|
382
|
-
|
|
383
|
-
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
384
|
-
key: Key = eqx.field(kw_only=True)
|
|
385
|
-
n: Int = eqx.field(kw_only=True, static=True)
|
|
386
|
-
nb: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
387
|
-
omega_batch_size: Int | None = eqx.field(
|
|
388
|
-
kw_only=True,
|
|
389
|
-
static=True,
|
|
390
|
-
default=None, # can be None as
|
|
391
|
-
# CubicMeshPDENonStatio inherits but also if omega_batch_size=n
|
|
392
|
-
) # static cause used as a
|
|
393
|
-
# shape in jax.lax.dynamic_slice
|
|
394
|
-
omega_border_batch_size: Int | None = eqx.field(
|
|
395
|
-
kw_only=True, static=True, default=None
|
|
396
|
-
) # static cause used as a
|
|
397
|
-
# shape in jax.lax.dynamic_slice
|
|
398
|
-
dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
|
|
399
|
-
# shape in jax.lax.dynamic_slice
|
|
400
|
-
min_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
401
|
-
max_pts: tuple[tuple[Float, Float], ...] = eqx.field(kw_only=True)
|
|
402
|
-
method: str = eqx.field(
|
|
403
|
-
kw_only=True, static=True, default_factory=lambda: "uniform"
|
|
404
|
-
)
|
|
405
|
-
rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
|
|
406
|
-
n_start: Int = eqx.field(kw_only=True, default=None, static=True)
|
|
407
|
-
|
|
408
|
-
# all the init=False fields are set in __post_init__
|
|
409
|
-
p: Float[Array, "n"] = eqx.field(init=False)
|
|
410
|
-
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
411
|
-
rar_iter_nb: Int = eqx.field(init=False)
|
|
412
|
-
curr_omega_idx: Int = eqx.field(init=False)
|
|
413
|
-
curr_omega_border_idx: Int = eqx.field(init=False)
|
|
414
|
-
omega: Float[Array, "n dim"] = eqx.field(init=False)
|
|
415
|
-
omega_border: Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None = eqx.field(
|
|
416
|
-
init=False
|
|
417
|
-
)
|
|
418
|
-
|
|
419
|
-
def __post_init__(self):
|
|
420
|
-
assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
|
|
421
|
-
assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
|
|
422
|
-
|
|
423
|
-
(
|
|
424
|
-
self.n_start,
|
|
425
|
-
self.p,
|
|
426
|
-
self.rar_iter_from_last_sampling,
|
|
427
|
-
self.rar_iter_nb,
|
|
428
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
429
|
-
|
|
430
|
-
if self.method == "grid" and self.dim == 2:
|
|
431
|
-
perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
432
|
-
if self.n != perfect_sq:
|
|
433
|
-
warnings.warn(
|
|
434
|
-
"Grid sampling is requested in dimension 2 with a non"
|
|
435
|
-
f" perfect square dataset size (self.n = {self.n})."
|
|
436
|
-
f" Modifying self.n to self.n = {perfect_sq}."
|
|
437
|
-
)
|
|
438
|
-
self.n = perfect_sq
|
|
439
|
-
|
|
440
|
-
if self.nb is not None:
|
|
441
|
-
if self.dim == 1:
|
|
442
|
-
self.omega_border_batch_size = None
|
|
443
|
-
# We are in 1-D case => omega_border_batch_size is
|
|
444
|
-
# ignored since borders of Omega are singletons.
|
|
445
|
-
# self.border_batch() will return [xmin, xmax]
|
|
446
|
-
else:
|
|
447
|
-
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
448
|
-
raise ValueError(
|
|
449
|
-
f"number of border point must be"
|
|
450
|
-
f" a multiple of 2xd = {2*self.dim} (the # of faces of"
|
|
451
|
-
f" a d-dimensional cube). Got {self.nb=}."
|
|
452
|
-
)
|
|
453
|
-
if (
|
|
454
|
-
self.omega_border_batch_size is not None
|
|
455
|
-
and self.nb // (2 * self.dim) < self.omega_border_batch_size
|
|
456
|
-
):
|
|
457
|
-
raise ValueError(
|
|
458
|
-
f"number of points per facets ({self.nb//(2*self.dim)})"
|
|
459
|
-
f" cannot be lower than border batch size "
|
|
460
|
-
f" ({self.omega_border_batch_size})."
|
|
461
|
-
)
|
|
462
|
-
self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
|
|
463
|
-
|
|
464
|
-
if self.omega_batch_size is None:
|
|
465
|
-
self.curr_omega_idx = 0
|
|
466
|
-
else:
|
|
467
|
-
self.curr_omega_idx = self.n + self.omega_batch_size
|
|
468
|
-
# to be sure there is a shuffling at first get_batch()
|
|
469
|
-
|
|
470
|
-
if self.omega_border_batch_size is None:
|
|
471
|
-
self.curr_omega_border_idx = 0
|
|
472
|
-
else:
|
|
473
|
-
self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
|
|
474
|
-
# to be sure there is a shuffling at first get_batch()
|
|
475
|
-
|
|
476
|
-
self.key, self.omega = self.generate_omega_data(self.key)
|
|
477
|
-
self.key, self.omega_border = self.generate_omega_border_data(self.key)
|
|
478
|
-
|
|
479
|
-
def sample_in_omega_domain(
|
|
480
|
-
self, keys: Key, sample_size: Int = None
|
|
481
|
-
) -> Float[Array, "n dim"]:
|
|
482
|
-
sample_size = self.n if sample_size is None else sample_size
|
|
483
|
-
if self.dim == 1:
|
|
484
|
-
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
485
|
-
return jax.random.uniform(
|
|
486
|
-
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
487
|
-
)
|
|
488
|
-
# keys = jax.random.split(key, self.dim)
|
|
489
|
-
return jnp.concatenate(
|
|
490
|
-
[
|
|
491
|
-
jax.random.uniform(
|
|
492
|
-
keys[i],
|
|
493
|
-
(sample_size, 1),
|
|
494
|
-
minval=self.min_pts[i],
|
|
495
|
-
maxval=self.max_pts[i],
|
|
496
|
-
)
|
|
497
|
-
for i in range(self.dim)
|
|
498
|
-
],
|
|
499
|
-
axis=-1,
|
|
500
|
-
)
|
|
501
|
-
|
|
502
|
-
def sample_in_omega_border_domain(
|
|
503
|
-
self, keys: Key, sample_size: int = None
|
|
504
|
-
) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
|
|
505
|
-
sample_size = self.nb if sample_size is None else sample_size
|
|
506
|
-
if sample_size is None:
|
|
507
|
-
return None
|
|
508
|
-
if self.dim == 1:
|
|
509
|
-
xmin = self.min_pts[0]
|
|
510
|
-
xmax = self.max_pts[0]
|
|
511
|
-
return jnp.array([xmin, xmax]).astype(float)
|
|
512
|
-
if self.dim == 2:
|
|
513
|
-
# currently hard-coded the 4 edges for d==2
|
|
514
|
-
# TODO : find a general & efficient way to sample from the border
|
|
515
|
-
# (facets) of the hypercube in general dim.
|
|
516
|
-
facet_n = sample_size // (2 * self.dim)
|
|
517
|
-
xmin = jnp.hstack(
|
|
518
|
-
[
|
|
519
|
-
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
520
|
-
jax.random.uniform(
|
|
521
|
-
keys[0],
|
|
522
|
-
(facet_n, 1),
|
|
523
|
-
minval=self.min_pts[1],
|
|
524
|
-
maxval=self.max_pts[1],
|
|
525
|
-
),
|
|
526
|
-
]
|
|
527
|
-
)
|
|
528
|
-
xmax = jnp.hstack(
|
|
529
|
-
[
|
|
530
|
-
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
531
|
-
jax.random.uniform(
|
|
532
|
-
keys[1],
|
|
533
|
-
(facet_n, 1),
|
|
534
|
-
minval=self.min_pts[1],
|
|
535
|
-
maxval=self.max_pts[1],
|
|
536
|
-
),
|
|
537
|
-
]
|
|
538
|
-
)
|
|
539
|
-
ymin = jnp.hstack(
|
|
540
|
-
[
|
|
541
|
-
jax.random.uniform(
|
|
542
|
-
keys[2],
|
|
543
|
-
(facet_n, 1),
|
|
544
|
-
minval=self.min_pts[0],
|
|
545
|
-
maxval=self.max_pts[0],
|
|
546
|
-
),
|
|
547
|
-
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
548
|
-
]
|
|
549
|
-
)
|
|
550
|
-
ymax = jnp.hstack(
|
|
551
|
-
[
|
|
552
|
-
jax.random.uniform(
|
|
553
|
-
keys[3],
|
|
554
|
-
(facet_n, 1),
|
|
555
|
-
minval=self.min_pts[0],
|
|
556
|
-
maxval=self.max_pts[0],
|
|
557
|
-
),
|
|
558
|
-
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
559
|
-
]
|
|
560
|
-
)
|
|
561
|
-
return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
562
|
-
raise NotImplementedError(
|
|
563
|
-
"Generation of the border of a cube in dimension > 2 is not "
|
|
564
|
-
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
565
|
-
)
|
|
566
|
-
|
|
567
|
-
def generate_omega_data(self, key: Key, data_size: int = None) -> tuple[
|
|
568
|
-
Key,
|
|
569
|
-
Float[Array, "n dim"],
|
|
570
|
-
]:
|
|
571
|
-
r"""
|
|
572
|
-
Construct a complete set of `self.n` $\Omega$ points according to the
|
|
573
|
-
specified `self.method`.
|
|
574
|
-
"""
|
|
575
|
-
data_size = self.n if data_size is None else data_size
|
|
576
|
-
# Generate Omega
|
|
577
|
-
if self.method == "grid":
|
|
578
|
-
if self.dim == 1:
|
|
579
|
-
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
580
|
-
## shape (n, 1)
|
|
581
|
-
omega = jnp.linspace(xmin, xmax, data_size)[:, None]
|
|
582
|
-
else:
|
|
583
|
-
xyz_ = jnp.meshgrid(
|
|
584
|
-
*[
|
|
585
|
-
jnp.linspace(
|
|
586
|
-
self.min_pts[i],
|
|
587
|
-
self.max_pts[i],
|
|
588
|
-
int(jnp.round(jnp.sqrt(data_size))),
|
|
589
|
-
)
|
|
590
|
-
for i in range(self.dim)
|
|
591
|
-
]
|
|
592
|
-
)
|
|
593
|
-
xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
|
|
594
|
-
omega = jnp.concatenate(xyz_, axis=-1)
|
|
595
|
-
elif self.method == "uniform":
|
|
596
|
-
if self.dim == 1:
|
|
597
|
-
key, subkeys = jax.random.split(key, 2)
|
|
598
|
-
else:
|
|
599
|
-
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
600
|
-
omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
|
|
601
|
-
else:
|
|
602
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
603
|
-
return key, omega
|
|
604
|
-
|
|
605
|
-
def generate_omega_border_data(self, key: Key, data_size: int = None) -> tuple[
|
|
606
|
-
Key,
|
|
607
|
-
Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
|
|
608
|
-
]:
|
|
609
|
-
r"""
|
|
610
|
-
Also constructs a complete set of `self.nb`
|
|
611
|
-
$\partial\Omega$ points if `self.omega_border_batch_size` is not
|
|
612
|
-
`None`. If the latter is `None` we set `self.omega_border` to `None`.
|
|
613
|
-
"""
|
|
614
|
-
# Generate border of omega
|
|
615
|
-
data_size = self.nb if data_size is None else data_size
|
|
616
|
-
if self.dim == 2:
|
|
617
|
-
key, *subkeys = jax.random.split(key, 5)
|
|
618
|
-
else:
|
|
619
|
-
subkeys = None
|
|
620
|
-
omega_border = self.sample_in_omega_border_domain(
|
|
621
|
-
subkeys, sample_size=data_size
|
|
622
|
-
)
|
|
623
|
-
|
|
624
|
-
return key, omega_border
|
|
625
|
-
|
|
626
|
-
def _get_omega_operands(
|
|
627
|
-
self,
|
|
628
|
-
) -> tuple[Key, Float[Array, "n dim"], Int, Int, Float[Array, "n"]]:
|
|
629
|
-
return (
|
|
630
|
-
self.key,
|
|
631
|
-
self.omega,
|
|
632
|
-
self.curr_omega_idx,
|
|
633
|
-
self.omega_batch_size,
|
|
634
|
-
self.p,
|
|
635
|
-
)
|
|
636
|
-
|
|
637
|
-
def inside_batch(
|
|
638
|
-
self,
|
|
639
|
-
) -> tuple["CubicMeshPDEStatio", Float[Array, "omega_batch_size dim"]]:
|
|
640
|
-
r"""
|
|
641
|
-
Return a batch of points in $\Omega$.
|
|
642
|
-
If all the batches have been seen, we reshuffle them,
|
|
643
|
-
otherwise we just return the next unseen batch.
|
|
644
|
-
"""
|
|
645
|
-
if self.omega_batch_size is None or self.omega_batch_size == self.n:
|
|
646
|
-
# Avoid unnecessary reshuffling
|
|
647
|
-
return self, self.omega
|
|
648
|
-
|
|
649
|
-
# Compute the effective number of used collocation points
|
|
650
|
-
if self.rar_parameters is not None:
|
|
651
|
-
n_eff = (
|
|
652
|
-
self.n_start
|
|
653
|
-
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
654
|
-
)
|
|
655
|
-
else:
|
|
656
|
-
n_eff = self.n
|
|
657
|
-
|
|
658
|
-
bstart = self.curr_omega_idx
|
|
659
|
-
bend = bstart + self.omega_batch_size
|
|
660
|
-
|
|
661
|
-
new_attributes = _reset_or_increment(bend, n_eff, self._get_omega_operands())
|
|
662
|
-
new = eqx.tree_at(
|
|
663
|
-
lambda m: (m.key, m.omega, m.curr_omega_idx), self, new_attributes
|
|
664
|
-
)
|
|
665
|
-
|
|
666
|
-
return new, jax.lax.dynamic_slice(
|
|
667
|
-
new.omega,
|
|
668
|
-
start_indices=(new.curr_omega_idx, 0),
|
|
669
|
-
slice_sizes=(new.omega_batch_size, new.dim),
|
|
670
|
-
)
|
|
671
|
-
|
|
672
|
-
def _get_omega_border_operands(
|
|
673
|
-
self,
|
|
674
|
-
) -> tuple[
|
|
675
|
-
Key, Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None, Int, Int, None
|
|
676
|
-
]:
|
|
677
|
-
return (
|
|
678
|
-
self.key,
|
|
679
|
-
self.omega_border,
|
|
680
|
-
self.curr_omega_border_idx,
|
|
681
|
-
self.omega_border_batch_size,
|
|
682
|
-
None,
|
|
683
|
-
)
|
|
684
|
-
|
|
685
|
-
def border_batch(
|
|
686
|
-
self,
|
|
687
|
-
) -> tuple[
|
|
688
|
-
"CubicMeshPDEStatio",
|
|
689
|
-
Float[Array, "1 1 2"] | Float[Array, "omega_border_batch_size 2 4"] | None,
|
|
690
|
-
]:
|
|
691
|
-
r"""
|
|
692
|
-
Return
|
|
693
|
-
|
|
694
|
-
- The value `None` if `self.omega_border_batch_size` is `None`.
|
|
695
|
-
|
|
696
|
-
- a jnp array with two fixed values $(x_{min}, x_{max})$ if
|
|
697
|
-
`self.dim` = 1. There is no sampling here, we return the entire
|
|
698
|
-
$\partial\Omega$
|
|
699
|
-
|
|
700
|
-
- a batch of points in $\partial\Omega$ otherwise, stacked by
|
|
701
|
-
facet on the last axis.
|
|
702
|
-
If all the batches have been seen, we reshuffle them,
|
|
703
|
-
otherwise we just return the next unseen batch.
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
"""
|
|
707
|
-
if self.nb is None:
|
|
708
|
-
# Avoid unnecessary reshuffling
|
|
709
|
-
return self, None
|
|
710
|
-
|
|
711
|
-
if self.dim == 1:
|
|
712
|
-
# Avoid unnecessary reshuffling
|
|
713
|
-
# 1-D case, no randomness : we always return the whole omega border,
|
|
714
|
-
# i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
|
|
715
|
-
return self, self.omega_border[None, None] # shape is (1, 1, 2)
|
|
716
|
-
|
|
717
|
-
if (
|
|
718
|
-
self.omega_border_batch_size is None
|
|
719
|
-
or self.omega_border_batch_size == self.nb // 2**self.dim
|
|
720
|
-
):
|
|
721
|
-
# Avoid unnecessary reshuffling
|
|
722
|
-
return self, self.omega_border
|
|
723
|
-
|
|
724
|
-
bstart = self.curr_omega_border_idx
|
|
725
|
-
bend = bstart + self.omega_border_batch_size
|
|
726
|
-
|
|
727
|
-
new_attributes = _reset_or_increment(
|
|
728
|
-
bend, self.nb, self._get_omega_border_operands()
|
|
729
|
-
)
|
|
730
|
-
new = eqx.tree_at(
|
|
731
|
-
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
|
|
732
|
-
self,
|
|
733
|
-
new_attributes,
|
|
734
|
-
)
|
|
735
|
-
|
|
736
|
-
return new, jax.lax.dynamic_slice(
|
|
737
|
-
new.omega_border,
|
|
738
|
-
start_indices=(new.curr_omega_border_idx, 0, 0),
|
|
739
|
-
slice_sizes=(new.omega_border_batch_size, new.dim, 2 * new.dim),
|
|
740
|
-
)
|
|
741
|
-
|
|
742
|
-
def get_batch(self) -> tuple["CubicMeshPDEStatio", PDEStatioBatch]:
|
|
743
|
-
"""
|
|
744
|
-
Generic method to return a batch. Here we call `self.inside_batch()`
|
|
745
|
-
and `self.border_batch()`
|
|
746
|
-
"""
|
|
747
|
-
new, inside_batch = self.inside_batch()
|
|
748
|
-
new, border_batch = new.border_batch()
|
|
749
|
-
return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
753
|
-
r"""
|
|
754
|
-
A class implementing data generator object for non stationary partial
|
|
755
|
-
differential equations. Formally, it extends `CubicMeshPDEStatio`
|
|
756
|
-
to include a temporal batch.
|
|
757
|
-
|
|
758
|
-
Parameters
|
|
759
|
-
----------
|
|
760
|
-
key : Key
|
|
761
|
-
Jax random key to sample new time points and to shuffle batches
|
|
762
|
-
n : Int
|
|
763
|
-
The number of total $I\times \Omega$ points that will be divided in
|
|
764
|
-
batches. Batches are made so that each data point is seen only
|
|
765
|
-
once during 1 epoch.
|
|
766
|
-
nb : Int | None
|
|
767
|
-
The total number of points in $\partial\Omega$. Can be None if no
|
|
768
|
-
boundary condition is specified.
|
|
769
|
-
ni : Int
|
|
770
|
-
The number of total $\Omega$ points at $t=0$ that will be divided in
|
|
771
|
-
batches. Batches are made so that each data point is seen only
|
|
772
|
-
once during 1 epoch.
|
|
773
|
-
domain_batch_size : Int | None, default=None
|
|
774
|
-
The size of the batch of randomly selected points among
|
|
775
|
-
the `n` points. If None no mini-batches are used.
|
|
776
|
-
border_batch_size : Int | None, default=None
|
|
777
|
-
The size of the batch of points randomly selected
|
|
778
|
-
among the `nb` points. If None, `domain_batch_size` no
|
|
779
|
-
mini-batches are used.
|
|
780
|
-
initial_batch_size : Int | None, default=None
|
|
781
|
-
The size of the batch of randomly selected points among
|
|
782
|
-
the `ni` points. If None no
|
|
783
|
-
mini-batches are used.
|
|
784
|
-
dim : Int
|
|
785
|
-
An integer. Dimension of $\Omega$ domain.
|
|
786
|
-
min_pts : tuple[tuple[Float, Float], ...]
|
|
787
|
-
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
788
|
-
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
789
|
-
x_{n, min})$
|
|
790
|
-
max_pts : tuple[tuple[Float, Float], ...]
|
|
791
|
-
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
792
|
-
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
793
|
-
x_{n,max})$
|
|
794
|
-
tmin : float
|
|
795
|
-
The minimum value of the time domain to consider
|
|
796
|
-
tmax : float
|
|
797
|
-
The maximum value of the time domain to consider
|
|
798
|
-
method : str, default="uniform"
|
|
799
|
-
Either `grid` or `uniform`, default is `uniform`.
|
|
800
|
-
The method that generates the `nt` time points. `grid` means
|
|
801
|
-
regularly spaced points over the domain. `uniform` means uniformly
|
|
802
|
-
sampled points over the domain
|
|
803
|
-
rar_parameters : Dict[str, Int], default=None
|
|
804
|
-
Defaults to None: do not use Residual Adaptative Resampling.
|
|
805
|
-
Otherwise a dictionary with keys
|
|
806
|
-
|
|
807
|
-
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
808
|
-
- `update_every`: the number of gradient steps taken between
|
|
809
|
-
each update of collocation points in the RAR algo.
|
|
810
|
-
- `sample_size`: the size of the sample from which we will select new
|
|
811
|
-
collocation points.
|
|
812
|
-
- `selected_sample_size`: the number of selected
|
|
813
|
-
points from the sample to be added to the current collocation
|
|
814
|
-
points.
|
|
815
|
-
n_start : Int, default=None
|
|
816
|
-
Defaults to None. The effective size of n used at start time.
|
|
817
|
-
This value must be
|
|
818
|
-
provided when rar_parameters is not None. Otherwise we set internally
|
|
819
|
-
n_start = n and this is hidden from the user.
|
|
820
|
-
In RAR, n_start
|
|
821
|
-
then corresponds to the initial number of omega points we train the PINN.
|
|
822
|
-
"""
|
|
823
|
-
|
|
824
|
-
tmin: Float = eqx.field(kw_only=True)
|
|
825
|
-
tmax: Float = eqx.field(kw_only=True)
|
|
826
|
-
ni: Int = eqx.field(kw_only=True, static=True)
|
|
827
|
-
domain_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
828
|
-
initial_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
829
|
-
border_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
830
|
-
|
|
831
|
-
curr_domain_idx: Int = eqx.field(init=False)
|
|
832
|
-
curr_initial_idx: Int = eqx.field(init=False)
|
|
833
|
-
curr_border_idx: Int = eqx.field(init=False)
|
|
834
|
-
domain: Float[Array, "n 1+dim"] = eqx.field(init=False)
|
|
835
|
-
border: Float[Array, "(nb//2) 1+1 2"] | Float[Array, "(nb//4) 2+1 4"] | None = (
|
|
836
|
-
eqx.field(init=False)
|
|
837
|
-
)
|
|
838
|
-
initial: Float[Array, "ni dim"] = eqx.field(init=False)
|
|
839
|
-
|
|
840
|
-
def __post_init__(self):
|
|
841
|
-
"""
|
|
842
|
-
Note that neither __init__ or __post_init__ are called when udating a
|
|
843
|
-
Module with eqx.tree_at!
|
|
844
|
-
"""
|
|
845
|
-
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
846
|
-
# class is not automatically called
|
|
847
|
-
|
|
848
|
-
if self.method == "grid":
|
|
849
|
-
# NOTE we must redo the sampling with the square root number of samples
|
|
850
|
-
# and then take the cartesian product
|
|
851
|
-
self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
852
|
-
if self.dim == 2:
|
|
853
|
-
# in the case of grid sampling in 2D in dim 2 in non-statio,
|
|
854
|
-
# self.n needs to be a perfect ^4, because there is the
|
|
855
|
-
# cartesian product with time domain which is also present
|
|
856
|
-
perfect_4 = int(jnp.round(self.n**0.25) ** 4)
|
|
857
|
-
if self.n != perfect_4:
|
|
858
|
-
warnings.warn(
|
|
859
|
-
"Grid sampling is requested in dimension 2 in non"
|
|
860
|
-
" stationary setting with a non"
|
|
861
|
-
f" perfect square dataset size (self.n = {self.n})."
|
|
862
|
-
f" Modifying self.n to self.n = {perfect_4}."
|
|
863
|
-
)
|
|
864
|
-
self.n = perfect_4
|
|
865
|
-
self.key, half_domain_times = self.generate_time_data(
|
|
866
|
-
self.key, int(jnp.round(jnp.sqrt(self.n)))
|
|
867
|
-
)
|
|
868
|
-
|
|
869
|
-
self.key, half_domain_omega = self.generate_omega_data(
|
|
870
|
-
self.key, data_size=int(jnp.round(jnp.sqrt(self.n)))
|
|
871
|
-
)
|
|
872
|
-
self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
|
|
873
|
-
|
|
874
|
-
# NOTE
|
|
875
|
-
(
|
|
876
|
-
self.n_start,
|
|
877
|
-
self.p,
|
|
878
|
-
self.rar_iter_from_last_sampling,
|
|
879
|
-
self.rar_iter_nb,
|
|
880
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
881
|
-
elif self.method == "uniform":
|
|
882
|
-
self.key, domain_times = self.generate_time_data(self.key, self.n)
|
|
883
|
-
self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
|
|
884
|
-
else:
|
|
885
|
-
raise ValueError(
|
|
886
|
-
f"Bad value for method. Got {self.method}, expected"
|
|
887
|
-
' "grid" or "uniform"'
|
|
888
|
-
)
|
|
889
|
-
|
|
890
|
-
if self.domain_batch_size is None:
|
|
891
|
-
self.curr_domain_idx = 0
|
|
892
|
-
else:
|
|
893
|
-
self.curr_domain_idx = self.n + self.domain_batch_size
|
|
894
|
-
# to be sure there is a shuffling at first get_batch()
|
|
895
|
-
if self.nb is not None:
|
|
896
|
-
# the check below has already been done in super.__post_init__ if
|
|
897
|
-
# dim > 1. Here we retest it in whatever dim
|
|
898
|
-
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
899
|
-
raise ValueError(
|
|
900
|
-
"number of border point must be"
|
|
901
|
-
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
902
|
-
)
|
|
903
|
-
# the check below concern omega_border_batch_size for dim > 1 in
|
|
904
|
-
# super.__post_init__. Here it concerns all dim values since our
|
|
905
|
-
# border_batch is the concatenation or cartesian product with times
|
|
906
|
-
if (
|
|
907
|
-
self.border_batch_size is not None
|
|
908
|
-
and self.nb // (2 * self.dim) < self.border_batch_size
|
|
909
|
-
):
|
|
910
|
-
raise ValueError(
|
|
911
|
-
"number of points per facets (nb//2*self.dim)"
|
|
912
|
-
" cannot be lower than border batch size"
|
|
913
|
-
)
|
|
914
|
-
self.key, boundary_times = self.generate_time_data(
|
|
915
|
-
self.key, self.nb // (2 * self.dim)
|
|
916
|
-
)
|
|
917
|
-
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
918
|
-
boundary_times = jnp.repeat(
|
|
919
|
-
boundary_times, self.omega_border.shape[-1], axis=2
|
|
920
|
-
)
|
|
921
|
-
if self.dim == 1:
|
|
922
|
-
self.border = make_cartesian_product(
|
|
923
|
-
boundary_times, self.omega_border[None, None]
|
|
924
|
-
)
|
|
925
|
-
else:
|
|
926
|
-
self.border = jnp.concatenate(
|
|
927
|
-
[boundary_times, self.omega_border], axis=1
|
|
928
|
-
)
|
|
929
|
-
if self.border_batch_size is None:
|
|
930
|
-
self.curr_border_idx = 0
|
|
931
|
-
else:
|
|
932
|
-
self.curr_border_idx = self.nb + self.border_batch_size
|
|
933
|
-
# to be sure there is a shuffling at first get_batch()
|
|
934
|
-
|
|
935
|
-
else:
|
|
936
|
-
self.border = None
|
|
937
|
-
self.curr_border_idx = None
|
|
938
|
-
self.border_batch_size = None
|
|
939
|
-
|
|
940
|
-
if self.ni is not None:
|
|
941
|
-
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
942
|
-
if self.ni != perfect_sq:
|
|
943
|
-
warnings.warn(
|
|
944
|
-
"Grid sampling is requested in dimension 2 with a non"
|
|
945
|
-
f" perfect square dataset size (self.ni = {self.ni})."
|
|
946
|
-
f" Modifying self.ni to self.ni = {perfect_sq}."
|
|
947
|
-
)
|
|
948
|
-
self.ni = perfect_sq
|
|
949
|
-
self.key, self.initial = self.generate_omega_data(
|
|
950
|
-
self.key, data_size=self.ni
|
|
951
|
-
)
|
|
952
|
-
|
|
953
|
-
if self.initial_batch_size is None or self.initial_batch_size == self.ni:
|
|
954
|
-
self.curr_initial_idx = 0
|
|
955
|
-
else:
|
|
956
|
-
self.curr_initial_idx = self.ni + self.initial_batch_size
|
|
957
|
-
# to be sure there is a shuffling at first get_batch()
|
|
958
|
-
else:
|
|
959
|
-
self.initial = None
|
|
960
|
-
self.initial_batch_size = None
|
|
961
|
-
self.curr_initial_idx = None
|
|
962
|
-
|
|
963
|
-
# the following attributes will not be used anymore
|
|
964
|
-
self.omega = None
|
|
965
|
-
self.omega_border = None
|
|
966
|
-
|
|
967
|
-
def generate_time_data(self, key: Key, nt: Int) -> tuple[Key, Float[Array, "nt 1"]]:
|
|
968
|
-
"""
|
|
969
|
-
Construct a complete set of `nt` time points according to the
|
|
970
|
-
specified `self.method`
|
|
971
|
-
"""
|
|
972
|
-
key, subkey = jax.random.split(key, 2)
|
|
973
|
-
if self.method == "grid":
|
|
974
|
-
partial_times = (self.tmax - self.tmin) / nt
|
|
975
|
-
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
976
|
-
if self.method == "uniform":
|
|
977
|
-
return key, self.sample_in_time_domain(subkey, nt)
|
|
978
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
979
|
-
|
|
980
|
-
def sample_in_time_domain(self, key: Key, nt: Int) -> Float[Array, "nt 1"]:
|
|
981
|
-
return jax.random.uniform(
|
|
982
|
-
key,
|
|
983
|
-
(nt, 1),
|
|
984
|
-
minval=self.tmin,
|
|
985
|
-
maxval=self.tmax,
|
|
986
|
-
)
|
|
987
|
-
|
|
988
|
-
def _get_domain_operands(
|
|
989
|
-
self,
|
|
990
|
-
) -> tuple[Key, Float[Array, "n 1+dim"], Int, Int, None]:
|
|
991
|
-
return (
|
|
992
|
-
self.key,
|
|
993
|
-
self.domain,
|
|
994
|
-
self.curr_domain_idx,
|
|
995
|
-
self.domain_batch_size,
|
|
996
|
-
self.p,
|
|
997
|
-
)
|
|
998
|
-
|
|
999
|
-
def domain_batch(
|
|
1000
|
-
self,
|
|
1001
|
-
) -> tuple["CubicMeshPDEStatio", Float[Array, "domain_batch_size 1+dim"]]:
|
|
1002
|
-
|
|
1003
|
-
if self.domain_batch_size is None or self.domain_batch_size == self.n:
|
|
1004
|
-
# Avoid unnecessary reshuffling
|
|
1005
|
-
return self, self.domain
|
|
1006
|
-
|
|
1007
|
-
bstart = self.curr_domain_idx
|
|
1008
|
-
bend = bstart + self.domain_batch_size
|
|
1009
|
-
|
|
1010
|
-
# Compute the effective number of used collocation points
|
|
1011
|
-
if self.rar_parameters is not None:
|
|
1012
|
-
n_eff = (
|
|
1013
|
-
self.n_start
|
|
1014
|
-
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
1015
|
-
)
|
|
1016
|
-
else:
|
|
1017
|
-
n_eff = self.n
|
|
1018
|
-
|
|
1019
|
-
new_attributes = _reset_or_increment(bend, n_eff, self._get_domain_operands())
|
|
1020
|
-
new = eqx.tree_at(
|
|
1021
|
-
lambda m: (m.key, m.domain, m.curr_domain_idx),
|
|
1022
|
-
self,
|
|
1023
|
-
new_attributes,
|
|
1024
|
-
)
|
|
1025
|
-
return new, jax.lax.dynamic_slice(
|
|
1026
|
-
new.domain,
|
|
1027
|
-
start_indices=(new.curr_domain_idx, 0),
|
|
1028
|
-
slice_sizes=(new.domain_batch_size, new.dim + 1),
|
|
1029
|
-
)
|
|
1030
|
-
|
|
1031
|
-
def _get_border_operands(
|
|
1032
|
-
self,
|
|
1033
|
-
) -> tuple[
|
|
1034
|
-
Key, Float[Array, "nb 1+1 2"] | Float[Array, "(nb//4) 2+1 4"], Int, Int, None
|
|
1035
|
-
]:
|
|
1036
|
-
return (
|
|
1037
|
-
self.key,
|
|
1038
|
-
self.border,
|
|
1039
|
-
self.curr_border_idx,
|
|
1040
|
-
self.border_batch_size,
|
|
1041
|
-
None,
|
|
1042
|
-
)
|
|
1043
|
-
|
|
1044
|
-
def border_batch(
|
|
1045
|
-
self,
|
|
1046
|
-
) -> tuple[
|
|
1047
|
-
"CubicMeshPDENonStatio",
|
|
1048
|
-
Float[Array, "border_batch_size 1+1 2"]
|
|
1049
|
-
| Float[Array, "border_batch_size 2+1 4"]
|
|
1050
|
-
| None,
|
|
1051
|
-
]:
|
|
1052
|
-
if self.nb is None:
|
|
1053
|
-
# Avoid unnecessary reshuffling
|
|
1054
|
-
return self, None
|
|
1055
|
-
|
|
1056
|
-
if (
|
|
1057
|
-
self.border_batch_size is None
|
|
1058
|
-
or self.border_batch_size == self.nb // 2**self.dim
|
|
1059
|
-
):
|
|
1060
|
-
# Avoid unnecessary reshuffling
|
|
1061
|
-
return self, self.border
|
|
1062
|
-
|
|
1063
|
-
bstart = self.curr_border_idx
|
|
1064
|
-
bend = bstart + self.border_batch_size
|
|
1065
|
-
|
|
1066
|
-
n_eff = self.border.shape[0]
|
|
1067
|
-
|
|
1068
|
-
new_attributes = _reset_or_increment(bend, n_eff, self._get_border_operands())
|
|
1069
|
-
new = eqx.tree_at(
|
|
1070
|
-
lambda m: (m.key, m.border, m.curr_border_idx),
|
|
1071
|
-
self,
|
|
1072
|
-
new_attributes,
|
|
1073
|
-
)
|
|
1074
|
-
|
|
1075
|
-
return new, jax.lax.dynamic_slice(
|
|
1076
|
-
new.border,
|
|
1077
|
-
start_indices=(new.curr_border_idx, 0, 0),
|
|
1078
|
-
slice_sizes=(
|
|
1079
|
-
new.border_batch_size,
|
|
1080
|
-
new.dim + 1,
|
|
1081
|
-
2 * new.dim,
|
|
1082
|
-
),
|
|
1083
|
-
)
|
|
1084
|
-
|
|
1085
|
-
def _get_initial_operands(
|
|
1086
|
-
self,
|
|
1087
|
-
) -> tuple[Key, Float[Array, "ni dim"], Int, Int, None]:
|
|
1088
|
-
return (
|
|
1089
|
-
self.key,
|
|
1090
|
-
self.initial,
|
|
1091
|
-
self.curr_initial_idx,
|
|
1092
|
-
self.initial_batch_size,
|
|
1093
|
-
None,
|
|
1094
|
-
)
|
|
1095
|
-
|
|
1096
|
-
def initial_batch(
|
|
1097
|
-
self,
|
|
1098
|
-
) -> tuple["CubicMeshPDEStatio", Float[Array, "initial_batch_size dim"]]:
|
|
1099
|
-
if self.initial_batch_size is None or self.initial_batch_size == self.ni:
|
|
1100
|
-
# Avoid unnecessary reshuffling
|
|
1101
|
-
return self, self.initial
|
|
1102
|
-
|
|
1103
|
-
bstart = self.curr_initial_idx
|
|
1104
|
-
bend = bstart + self.initial_batch_size
|
|
1105
|
-
|
|
1106
|
-
n_eff = self.ni
|
|
1107
|
-
|
|
1108
|
-
new_attributes = _reset_or_increment(bend, n_eff, self._get_initial_operands())
|
|
1109
|
-
new = eqx.tree_at(
|
|
1110
|
-
lambda m: (m.key, m.initial, m.curr_initial_idx),
|
|
1111
|
-
self,
|
|
1112
|
-
new_attributes,
|
|
1113
|
-
)
|
|
1114
|
-
return new, jax.lax.dynamic_slice(
|
|
1115
|
-
new.initial,
|
|
1116
|
-
start_indices=(new.curr_initial_idx, 0),
|
|
1117
|
-
slice_sizes=(new.initial_batch_size, new.dim),
|
|
1118
|
-
)
|
|
1119
|
-
|
|
1120
|
-
def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
|
|
1121
|
-
"""
|
|
1122
|
-
Generic method to return a batch. Here we call `self.domain_batch()`,
|
|
1123
|
-
`self.border_batch()` and `self.initial_batch()`
|
|
1124
|
-
"""
|
|
1125
|
-
new, domain = self.domain_batch()
|
|
1126
|
-
if self.border is not None:
|
|
1127
|
-
new, border = new.border_batch()
|
|
1128
|
-
else:
|
|
1129
|
-
border = None
|
|
1130
|
-
if self.initial is not None:
|
|
1131
|
-
new, initial = new.initial_batch()
|
|
1132
|
-
else:
|
|
1133
|
-
initial = None
|
|
1134
|
-
|
|
1135
|
-
return new, PDENonStatioBatch(
|
|
1136
|
-
domain_batch=domain, border_batch=border, initial_batch=initial
|
|
1137
|
-
)
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
class DataGeneratorObservations(eqx.Module):
|
|
1141
|
-
r"""
|
|
1142
|
-
Despite the class name, it is rather a dataloader for user-provided
|
|
1143
|
-
observations which will are used in the observations loss.
|
|
1144
|
-
|
|
1145
|
-
Parameters
|
|
1146
|
-
----------
|
|
1147
|
-
key : Key
|
|
1148
|
-
Jax random key to shuffle batches
|
|
1149
|
-
obs_batch_size : Int | None
|
|
1150
|
-
The size of the batch of randomly selected points among
|
|
1151
|
-
the `n` points. If None, no minibatch are used.
|
|
1152
|
-
observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
|
|
1153
|
-
Observed values corresponding to the input of the PINN
|
|
1154
|
-
(eg. the time at which we recorded the observations). The first
|
|
1155
|
-
dimension must corresponds to the number of observed_values.
|
|
1156
|
-
The second dimension depends on the input dimension of the PINN,
|
|
1157
|
-
that is `1` for ODE, `n_dim_x` for stationnary PDE and `n_dim_x + 1`
|
|
1158
|
-
for non-stationnary PDE.
|
|
1159
|
-
observed_values : Float[Array, "n_obs, nb_pinn_out"]
|
|
1160
|
-
Observed values that the PINN should learn to fit. The first
|
|
1161
|
-
dimension must be aligned with observed_pinn_in.
|
|
1162
|
-
observed_eq_params : Dict[str, Float[Array, "n_obs 1"]], default={}
|
|
1163
|
-
A dict with keys corresponding to
|
|
1164
|
-
the parameter name. The keys must match the keys in
|
|
1165
|
-
`params["eq_params"]`. The values are jnp.array with 2 dimensions
|
|
1166
|
-
with values corresponding to the parameter value for which we also
|
|
1167
|
-
have observed_pinn_in and observed_values. Hence the first
|
|
1168
|
-
dimension must be aligned with observed_pinn_in and observed_values.
|
|
1169
|
-
Optional argument.
|
|
1170
|
-
sharding_device : jax.sharding.Sharding, default=None
|
|
1171
|
-
Default None. An optional sharding object to constraint the storage
|
|
1172
|
-
of observed inputs, values and parameters. Typically, a
|
|
1173
|
-
SingleDeviceSharding(cpu_device) to avoid loading on GPU huge
|
|
1174
|
-
datasets of observations. Note that computations for **batches**
|
|
1175
|
-
can still be performed on other devices (*e.g.* GPU, TPU or
|
|
1176
|
-
any pre-defined Sharding) thanks to the `obs_batch_sharding`
|
|
1177
|
-
arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
|
|
1178
|
-
"""
|
|
1179
|
-
|
|
1180
|
-
key: Key
|
|
1181
|
-
obs_batch_size: Int | None = eqx.field(static=True)
|
|
1182
|
-
observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
|
|
1183
|
-
observed_values: Float[Array, "n_obs nb_pinn_out"]
|
|
1184
|
-
observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
|
|
1185
|
-
static=True, default_factory=lambda: {}
|
|
1186
|
-
)
|
|
1187
|
-
sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
|
|
1188
|
-
|
|
1189
|
-
n: Int = eqx.field(init=False, static=True)
|
|
1190
|
-
curr_idx: Int = eqx.field(init=False)
|
|
1191
|
-
indices: Array = eqx.field(init=False)
|
|
1192
|
-
|
|
1193
|
-
def __post_init__(self):
|
|
1194
|
-
if self.observed_pinn_in.shape[0] != self.observed_values.shape[0]:
|
|
1195
|
-
raise ValueError(
|
|
1196
|
-
"self.observed_pinn_in and self.observed_values must have same first axis"
|
|
1197
|
-
)
|
|
1198
|
-
for _, v in self.observed_eq_params.items():
|
|
1199
|
-
if v.shape[0] != self.observed_pinn_in.shape[0]:
|
|
1200
|
-
raise ValueError(
|
|
1201
|
-
"self.observed_pinn_in and the values of"
|
|
1202
|
-
" self.observed_eq_params must have the same first axis"
|
|
1203
|
-
)
|
|
1204
|
-
if len(self.observed_pinn_in.shape) == 1:
|
|
1205
|
-
self.observed_pinn_in = self.observed_pinn_in[:, None]
|
|
1206
|
-
if len(self.observed_pinn_in.shape) > 2:
|
|
1207
|
-
raise ValueError("self.observed_pinn_in must have 2 dimensions")
|
|
1208
|
-
if len(self.observed_values.shape) == 1:
|
|
1209
|
-
self.observed_values = self.observed_values[:, None]
|
|
1210
|
-
if len(self.observed_values.shape) > 2:
|
|
1211
|
-
raise ValueError("self.observed_values must have 2 dimensions")
|
|
1212
|
-
for k, v in self.observed_eq_params.items():
|
|
1213
|
-
if len(v.shape) == 1:
|
|
1214
|
-
self.observed_eq_params[k] = v[:, None]
|
|
1215
|
-
if len(v.shape) > 2:
|
|
1216
|
-
raise ValueError(
|
|
1217
|
-
"Each value of observed_eq_params must have 2 dimensions"
|
|
1218
|
-
)
|
|
1219
|
-
|
|
1220
|
-
self.n = self.observed_pinn_in.shape[0]
|
|
1221
|
-
|
|
1222
|
-
if self.sharding_device is not None:
|
|
1223
|
-
self.observed_pinn_in = jax.lax.with_sharding_constraint(
|
|
1224
|
-
self.observed_pinn_in, self.sharding_device
|
|
1225
|
-
)
|
|
1226
|
-
self.observed_values = jax.lax.with_sharding_constraint(
|
|
1227
|
-
self.observed_values, self.sharding_device
|
|
1228
|
-
)
|
|
1229
|
-
self.observed_eq_params = jax.lax.with_sharding_constraint(
|
|
1230
|
-
self.observed_eq_params, self.sharding_device
|
|
1231
|
-
)
|
|
1232
|
-
|
|
1233
|
-
if self.obs_batch_size is not None:
|
|
1234
|
-
self.curr_idx = self.n + self.obs_batch_size
|
|
1235
|
-
# to be sure there is a shuffling at first get_batch()
|
|
1236
|
-
else:
|
|
1237
|
-
self.curr_idx = 0
|
|
1238
|
-
# For speed and to avoid duplicating data what is really
|
|
1239
|
-
# shuffled is a vector of indices
|
|
1240
|
-
if self.sharding_device is not None:
|
|
1241
|
-
self.indices = jax.lax.with_sharding_constraint(
|
|
1242
|
-
jnp.arange(self.n), self.sharding_device
|
|
1243
|
-
)
|
|
1244
|
-
else:
|
|
1245
|
-
self.indices = jnp.arange(self.n)
|
|
1246
|
-
|
|
1247
|
-
# recall post_init is the only place with _init_ where we can set
|
|
1248
|
-
# self attribute in a in-place way
|
|
1249
|
-
self.key, _ = jax.random.split(self.key, 2) # to make it equivalent to
|
|
1250
|
-
# the call to _reset_batch_idx_and_permute in legacy DG
|
|
1251
|
-
|
|
1252
|
-
def _get_operands(self) -> tuple[Key, Int[Array, "n"], Int, Int, None]:
|
|
1253
|
-
return (
|
|
1254
|
-
self.key,
|
|
1255
|
-
self.indices,
|
|
1256
|
-
self.curr_idx,
|
|
1257
|
-
self.obs_batch_size,
|
|
1258
|
-
None,
|
|
1259
|
-
)
|
|
1260
|
-
|
|
1261
|
-
    def obs_batch(
        self,
    ) -> tuple[
        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
    ]:
        """
        Return a dictionary with (key, value) pairs: (pinn_in, a mini-batch of
        PINN inputs), (val, a mini-batch of corresponding observations), and
        (eq_params, a dictionary with entry names found in
        `params["eq_params"]` whose values give the corresponding parameter
        value for each (input, observation) pair above).
        It can also be a dictionary of such dictionaries if observed_pinn_in,
        observed_values, etc. are dictionaries with keys representing the
        PINNs.
        """
        if self.obs_batch_size is None or self.obs_batch_size == self.n:
            # Avoid unnecessary reshuffling
            return self, {
                "pinn_in": self.observed_pinn_in,
                "val": self.observed_values,
                "eq_params": self.observed_eq_params,
            }

        new_attributes = _reset_or_increment(
            self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
        )
        new = eqx.tree_at(
            lambda m: (m.key, m.indices, m.curr_idx), self, new_attributes
        )

        minib_indices = jax.lax.dynamic_slice(
            new.indices,
            start_indices=(new.curr_idx,),
            slice_sizes=(new.obs_batch_size,),
        )

        obs_batch = {
            "pinn_in": jnp.take(
                new.observed_pinn_in, minib_indices, unique_indices=True, axis=0
            ),
            "val": jnp.take(
                new.observed_values, minib_indices, unique_indices=True, axis=0
            ),
            "eq_params": jax.tree_util.tree_map(
                lambda a: jnp.take(a, minib_indices, unique_indices=True, axis=0),
                new.observed_eq_params,
            ),
        }
        return new, obs_batch

    def get_batch(
        self,
    ) -> tuple[
        "DataGeneratorObservations", Dict[str, Float[Array, "obs_batch_size dim"]]
    ]:
        """
        Generic method to return a batch.
        """
        return self.obs_batch()

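A minimal usage sketch of DataGeneratorObservations (illustrative, not from the package: the field order follows the positional call in DataGeneratorObservationsMultiPINNs further down, and "theta" is a hypothetical equation-parameter name):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    t_obs = jnp.linspace(0.0, 1.0, 100)[:, None]  # PINN inputs, shape (100, 1)
    u_obs = jnp.exp(-t_obs)                       # observations, shape (100, 1)

    data_gen = DataGeneratorObservations(
        key=key,
        obs_batch_size=32,
        observed_pinn_in=t_obs,
        observed_values=u_obs,
        observed_eq_params={"theta": jnp.ones((100, 1))},
    )

    # eqx.Module is immutable: get_batch() returns an updated copy of the
    # generator (new key, cursor, indices) together with the batch dictionary
    data_gen, batch = data_gen.get_batch()
    assert batch["pinn_in"].shape == (32, 1) and batch["val"].shape == (32, 1)
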
class DataGeneratorParameter(eqx.Module):
    r"""
    A data generator for additional unidimensional equation parameter(s).
    Mostly useful for metamodeling, where batches of `params.eq_params` are
    fed to the network.

    Parameters
    ----------
    keys : Key | Dict[str, Key]
        A Jax random key to sample new points and to shuffle batches,
        or a dict of Jax random keys with key entries from param_ranges
    n : Int
        The total number of points that will be divided into
        batches. Batches are made so that each data point is seen only
        once during one epoch.
    param_batch_size : Int | None, default=None
        The size of the batch of randomly selected points among
        the `n` points. **Important**: no check is performed, but
        `param_batch_size` must be the same as the other collocation points
        batch size (time, space or time x space, depending on the context),
        because we vmap the network on all its axes at once to compute the
        MSE. Also, `param_batch_size` is the same for all parameters. If
        None, no mini-batches are used.
    param_ranges : Dict[str, tuple[Float, Float]] | None, default={}
        A dict of tuples (min, max) which represent the ranges of real
        numbers where to sample the `n` points (from which batches of
        length `param_batch_size` are drawn).
        Each key corresponds to a parameter name; the keys must match the
        keys in `params["eq_params"]`.
        By providing several entries in this dictionary we can sample
        an arbitrary number of parameters.
        **Note** that we currently only support unidimensional parameters.
        This argument can be None if we use `user_data`.
    method : str, default="uniform"
        Either `grid` or `uniform`. `grid` means regularly spaced points
        over the domain; `uniform` means uniformly sampled points over the
        domain.
    user_data : Dict[str, Float[Array, "n"]] | None, default={}
        A dictionary containing user-provided data for parameters.
        The keys correspond to the parameter names
        and must match the keys in `params["eq_params"]`. Only
        unidimensional `jnp.array` are supported. Therefore, the array at
        `user_data[k]` must have shape `(n, 1)` or `(n,)`.
        Note that if the same key appears in `param_ranges` and `user_data`,
        priority goes to the content in `user_data`.
    """

    keys: Key | Dict[str, Key]
    n: Int = eqx.field(static=True)
    param_batch_size: Int | None = eqx.field(static=True, default=None)
    param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
        static=True, default_factory=lambda: {}
    )
    method: str = eqx.field(static=True, default="uniform")
    user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
        default_factory=lambda: {}
    )

    curr_param_idx: Dict[str, Int] = eqx.field(init=False)
    param_n_samples: Dict[str, Array] = eqx.field(init=False)

    def __post_init__(self):
        if self.user_data is None:
            self.user_data = {}
        if self.param_ranges is None:
            self.param_ranges = {}
        if self.param_batch_size is not None and self.n < self.param_batch_size:
            raise ValueError(
                f"Number of data points ({self.n}) is smaller than the"
                f" number of batch points ({self.param_batch_size})."
            )
        if not isinstance(self.keys, dict):
            all_keys = set().union(self.param_ranges, self.user_data)
            self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))

        if self.param_batch_size is None:
            self.curr_param_idx = None
        else:
            self.curr_param_idx = {}
            for k in self.keys.keys():
                self.curr_param_idx[k] = self.n + self.param_batch_size
                # to be sure there is a shuffling at the first get_batch()

        # The call to self.generate_data() creates the dict
        # self.param_n_samples, and afterwards we only use this one
        # because it merges the data scattered between `user_data` and
        # `param_ranges`
        self.keys, self.param_n_samples = self.generate_data(self.keys)

    def generate_data(
        self, keys: Dict[str, Key]
    ) -> tuple[Dict[str, Key], Dict[str, Float[Array, "n"]]]:
        """
        Generate the parameter samples, either by sampling them
        or by using the user-provided data.
        """
        param_n_samples = {}

        all_keys = set().union(self.param_ranges, self.user_data)
        for k in all_keys:
            if (
                self.user_data
                and k in self.user_data.keys()  # pylint: disable=no-member
            ):
                if self.user_data[k].shape == (self.n, 1):
                    param_n_samples[k] = self.user_data[k]
                elif self.user_data[k].shape == (self.n,):
                    param_n_samples[k] = self.user_data[k][:, None]
                else:
                    raise ValueError(
                        "Wrong shape for user-provided parameters"
                        f" in user_data dictionary at key='{k}'"
                    )
            else:
                if self.method == "grid":
                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
                    partial = (xmax - xmin) / self.n
                    # shape (n, 1)
                    param_n_samples[k] = jnp.arange(xmin, xmax, partial)[:, None]
                elif self.method == "uniform":
                    xmin, xmax = self.param_ranges[k][0], self.param_ranges[k][1]
                    keys[k], subkey = jax.random.split(keys[k], 2)
                    param_n_samples[k] = jax.random.uniform(
                        subkey, shape=(self.n, 1), minval=xmin, maxval=xmax
                    )
                else:
                    raise ValueError("Method " + self.method + " is not implemented.")

        return keys, param_n_samples

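The two sampling branches above only differ in how the n values are produced; both store them with shape (n, 1). A minimal sketch of their output (plain JAX; "mu" is a hypothetical parameter name):

    import jax
    import jax.numpy as jnp

    n, (xmin, xmax) = 8, (0.5, 2.0)

    # method="grid": n regularly spaced values over [xmin, xmax)
    mu_grid = jnp.arange(xmin, xmax, (xmax - xmin) / n)[:, None]

    # method="uniform": n i.i.d. draws from U(xmin, xmax)
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    mu_uniform = jax.random.uniform(subkey, shape=(n, 1), minval=xmin, maxval=xmax)

    assert mu_grid.shape == mu_uniform.shape == (n, 1)
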
    def _get_param_operands(
        self, k: str
    ) -> tuple[Key, Float[Array, "n"], Int, Int, None]:
        return (
            self.keys[k],
            self.param_n_samples[k],
            self.curr_param_idx[k],
            self.param_batch_size,
            None,
        )

    def param_batch(self):
        """
        Return a dictionary with batches of parameters.
        If all the batches have been seen, we reshuffle them;
        otherwise we just return the next unseen batch.
        """
        if self.param_batch_size is None or self.param_batch_size == self.n:
            return self, self.param_n_samples

        def _reset_or_increment_wrapper(param_k, idx_k, key_k):
            return _reset_or_increment(
                idx_k + self.param_batch_size,
                self.n,
                (key_k, param_k, idx_k, self.param_batch_size, None),
            )

        res = jax.tree_util.tree_map(
            _reset_or_increment_wrapper,
            self.param_n_samples,
            self.curr_param_idx,
            self.keys,
        )
        # we must transpose the pytree because, for each parameter, res merges
        # the new key, samples and index
        # https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#transposing-trees
        new_attributes = jax.tree_util.tree_transpose(
            jax.tree_util.tree_structure(self.keys),
            jax.tree_util.tree_structure([0, 0, 0]),
            res,
        )

        new = eqx.tree_at(
            lambda m: (m.keys, m.param_n_samples, m.curr_param_idx),
            self,
            new_attributes,
        )

        return new, jax.tree_util.tree_map(
            lambda p, q: jax.lax.dynamic_slice(
                p, start_indices=(q, 0), slice_sizes=(new.param_batch_size, 1)
            ),
            new.param_n_samples,
            new.curr_param_idx,
        )

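The tree_transpose call above turns a dict of per-parameter triples into a triple of dicts, which is the shape eqx.tree_at expects for the three attributes. A self-contained sketch of that pattern (illustrative values only):

    import jax

    # one [key, samples, index] triple per parameter, as returned by the
    # tree_map over {"mu": ..., "nu": ...}
    res = {"mu": [1, 2, 3], "nu": [4, 5, 6]}
    outer = jax.tree_util.tree_structure({"mu": 0, "nu": 0})
    inner = jax.tree_util.tree_structure([0, 0, 0])
    jax.tree_util.tree_transpose(outer, inner, res)
    # -> [{'mu': 1, 'nu': 4}, {'mu': 2, 'nu': 5}, {'mu': 3, 'nu': 6}]
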
    def get_batch(self):
        """
        Generic method to return a batch.
        """
        return self.param_batch()

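A minimal usage sketch of DataGeneratorParameter (illustrative: "mu" is a hypothetical equation-parameter name; the other arguments follow the signature above):

    import jax

    key = jax.random.PRNGKey(0)
    param_data_gen = DataGeneratorParameter(
        keys=key,
        n=1000,
        param_batch_size=32,
        param_ranges={"mu": (0.5, 2.0)},
        method="uniform",
    )
    # returns the updated generator and a dict of (32, 1) parameter batches
    param_data_gen, param_batch = param_data_gen.get_batch()
    assert param_batch["mu"].shape == (32, 1)
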
class DataGeneratorObservationsMultiPINNs(eqx.Module):
    r"""
    Despite the class name, this is rather a dataloader from user-provided
    observations that will be used for the observation loss.
    This is the DataGenerator to use when dealing with multiple PINNs
    (`u_dict`) in SystemLossODE/SystemLossPDE.

    Technically, the constraint on the observations in SystemLossXDE is
    applied in `constraints_system_loss_apply`; in this case
    `batch.obs_batch_dict` is a dict of obs_batch_dict over which the tree_map
    applies (we select the obs_batch_dict corresponding to its `u_dict` entry).

    Parameters
    ----------
    obs_batch_size : Int
        The size of the batch of randomly selected observations.
        `obs_batch_size` will be the same for all the
        elements of the obs dict.
    observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
        A dict of observed_pinn_in as defined in DataGeneratorObservations.
        Keys must be those of `u_dict`.
        If no observation exists for a particular entry of `u_dict`, the
        corresponding key must still exist in observed_pinn_in_dict with
        value None.
    observed_values_dict : Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
        A dict of observed_values as defined in DataGeneratorObservations.
        Keys must be those of `u_dict`.
        If no observation exists for a particular entry of `u_dict`, the
        corresponding key must still exist in observed_values_dict with
        value None.
    observed_eq_params_dict : Dict[str, Dict[str, Float[Array, "n_obs 1"]]]
        A dict of observed_eq_params as defined in DataGeneratorObservations.
        Keys must be those of `u_dict`.
        **Note**: if no observation exists for a particular entry of `u_dict`,
        the corresponding key must still exist in observed_eq_params_dict with
        value `{}` (empty dictionary).
    key
        A Jax random key to shuffle batches.
    """

    obs_batch_size: Int
    observed_pinn_in_dict: Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
    observed_values_dict: Dict[str, Float[Array, "n_obs nb_pinn_out"] | None]
    observed_eq_params_dict: Dict[str, Dict[str, Float[Array, "n_obs 1"]]] = eqx.field(
        default=None, kw_only=True
    )
    key: InitVar[Key]

    data_gen_obs: Dict[str, "DataGeneratorObservations"] = eqx.field(init=False)

    def __post_init__(self, key):
        if self.observed_pinn_in_dict is None or self.observed_values_dict is None:
            raise ValueError(
                "observed_pinn_in_dict and observed_values_dict must be provided"
            )
        if self.observed_pinn_in_dict.keys() != self.observed_values_dict.keys():
            raise ValueError(
                "Keys must be the same in observed_pinn_in_dict"
                " and observed_values_dict"
            )

        if self.observed_eq_params_dict is None:
            self.observed_eq_params_dict = {
                k: {} for k in self.observed_pinn_in_dict.keys()
            }
        elif self.observed_pinn_in_dict.keys() != self.observed_eq_params_dict.keys():
            raise ValueError(
                "Keys must be the same in observed_eq_params_dict"
                " and observed_pinn_in_dict and observed_values_dict"
            )

        keys = dict(
            zip(
                self.observed_pinn_in_dict.keys(),
                jax.random.split(key, len(self.observed_pinn_in_dict)),
            )
        )
        self.data_gen_obs = jax.tree_util.tree_map(
            lambda k, pinn_in, val, eq_params: (
                DataGeneratorObservations(
                    k, self.obs_batch_size, pinn_in, val, eq_params
                )
                if pinn_in is not None
                else None
            ),
            keys,
            self.observed_pinn_in_dict,
            self.observed_values_dict,
            self.observed_eq_params_dict,
        )

    def obs_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
        """
        Return a dictionary of DataGeneratorObservations.obs_batch with keys
        from `u_dict`.
        """
        data_gen_and_batch_pytree = jax.tree_util.tree_map(
            lambda a: a.get_batch() if a is not None else {},
            self.data_gen_obs,
            is_leaf=lambda x: isinstance(x, DataGeneratorObservations),
        )  # note the is_leaf argument: it stops the traversal at the
        # DataGeneratorObservations themselves, so that we can call the method
        # on the element(s) of self.data_gen_obs which are not None
        new_attribute = jax.tree_util.tree_map(
            lambda a: a[0],
            data_gen_and_batch_pytree,
            is_leaf=lambda x: isinstance(x, tuple),
        )
        new = eqx.tree_at(lambda m: m.data_gen_obs, self, new_attribute)
        batches = jax.tree_util.tree_map(
            lambda a: a[1],
            data_gen_and_batch_pytree,
            is_leaf=lambda x: isinstance(x, tuple),
        )

        return new, batches

    def get_batch(self) -> tuple["DataGeneratorObservationsMultiPINNs", PyTree]:
        """
        Generic method to return a batch.
        """
        return self.obs_batch()
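
A minimal usage sketch of DataGeneratorObservationsMultiPINNs (illustrative: "u" and "v" stand for the keys of a hypothetical two-PINN `u_dict`; both entries are given observations here, but per the docstring an entry without observations would carry the value None):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    t = jnp.linspace(0.0, 1.0, 50)[:, None]

    multi_data_gen = DataGeneratorObservationsMultiPINNs(
        obs_batch_size=16,
        observed_pinn_in_dict={"u": t, "v": t},
        observed_values_dict={"u": jnp.sin(t), "v": jnp.cos(t)},
        key=key,
    )
    # returns the updated generator and one obs_batch dict per u_dict key
    multi_data_gen, batch_dict = multi_data_gen.get_batch()
    assert batch_dict["u"]["pinn_in"].shape == (16, 1)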