jinns 1.0.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_Batchs.py +4 -8
- jinns/data/_DataGenerators.py +532 -341
- jinns/loss/_DynamicLoss.py +150 -173
- jinns/loss/_DynamicLossAbstract.py +27 -73
- jinns/loss/_LossODE.py +45 -26
- jinns/loss/_LossPDE.py +85 -84
- jinns/loss/__init__.py +7 -6
- jinns/loss/_boundary_conditions.py +148 -279
- jinns/loss/_loss_utils.py +85 -58
- jinns/loss/_operators.py +441 -184
- jinns/parameters/_derivative_keys.py +487 -60
- jinns/plot/_plot.py +111 -98
- jinns/solver/_rar.py +102 -407
- jinns/solver/_solve.py +73 -38
- jinns/solver/_utils.py +122 -0
- jinns/utils/__init__.py +2 -0
- jinns/utils/_containers.py +3 -1
- jinns/utils/_hyperpinn.py +17 -7
- jinns/utils/_pinn.py +17 -27
- jinns/utils/_ppinn.py +227 -0
- jinns/utils/_save_load.py +13 -13
- jinns/utils/_spinn.py +24 -43
- jinns/utils/_types.py +1 -0
- jinns/utils/_utils.py +40 -12
- jinns-1.2.0.dist-info/AUTHORS +2 -0
- jinns-1.2.0.dist-info/METADATA +127 -0
- jinns-1.2.0.dist-info/RECORD +41 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/WHEEL +1 -1
- jinns-1.0.0.dist-info/METADATA +0 -84
- jinns-1.0.0.dist-info/RECORD +0 -38
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/LICENSE +0 -0
- {jinns-1.0.0.dist-info → jinns-1.2.0.dist-info}/top_level.txt +0 -0
jinns/data/_DataGenerators.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
# pylint: disable=unsubscriptable-object
|
|
2
2
|
"""
|
|
3
|
-
Define the
|
|
3
|
+
Define the DataGenerators modules
|
|
4
4
|
"""
|
|
5
5
|
from __future__ import (
|
|
6
6
|
annotations,
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
|
-
|
|
8
|
+
import warnings
|
|
9
9
|
from typing import TYPE_CHECKING, Dict
|
|
10
10
|
from dataclasses import InitVar
|
|
11
11
|
import equinox as eqx
|
|
@@ -20,8 +20,7 @@ if TYPE_CHECKING:
|
|
|
20
20
|
|
|
21
21
|
def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
|
|
22
22
|
"""
|
|
23
|
-
Utility function that
|
|
24
|
-
param_batch_dict
|
|
23
|
+
Utility function that fills the field `batch.param_batch_dict` of a batch object.
|
|
25
24
|
"""
|
|
26
25
|
return eqx.tree_at(
|
|
27
26
|
lambda m: m.param_batch_dict,
|
|
@@ -33,8 +32,7 @@ def append_param_batch(batch: AnyBatch, param_batch_dict: dict) -> AnyBatch:
|
|
|
33
32
|
|
|
34
33
|
def append_obs_batch(batch: AnyBatch, obs_batch_dict: dict) -> AnyBatch:
|
|
35
34
|
"""
|
|
36
|
-
Utility function that
|
|
37
|
-
obs_batch_dict
|
|
35
|
+
Utility function that fills the field `batch.obs_batch_dict` of a batch object
|
|
38
36
|
"""
|
|
39
37
|
return eqx.tree_at(
|
|
40
38
|
lambda m: m.obs_batch_dict, batch, obs_batch_dict, is_leaf=lambda x: x is None
|
|
@@ -63,12 +61,17 @@ def _reset_batch_idx_and_permute(
|
|
|
63
61
|
curr_idx = 0
|
|
64
62
|
# reshuffling
|
|
65
63
|
key, subkey = jax.random.split(key)
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
64
|
+
if p is None:
|
|
65
|
+
domain = jax.random.permutation(subkey, domain, axis=0, independent=False)
|
|
66
|
+
else:
|
|
67
|
+
# otherwise p is used to avoid collocation points not in n_start
|
|
68
|
+
# NOTE that replace=True to avoid undefined behaviour but then, the
|
|
69
|
+
# domain.shape[0] does not really grow as in the original RAR. instead,
|
|
70
|
+
# it always comprises the same number of points, but the points are
|
|
71
|
+
# updated
|
|
72
|
+
domain = jax.random.choice(
|
|
73
|
+
subkey, domain, shape=(domain.shape[0],), replace=True, p=p
|
|
74
|
+
)
|
|
72
75
|
|
|
73
76
|
# return updated
|
|
74
77
|
return (key, domain, curr_idx)
|
|
@@ -121,13 +124,13 @@ def _check_and_set_rar_parameters(
|
|
|
121
124
|
) -> tuple[Int, Float[Array, "n"], Int, Int]:
|
|
122
125
|
if rar_parameters is not None and n_start is None:
|
|
123
126
|
raise ValueError(
|
|
124
|
-
"
|
|
127
|
+
"n_start must be provided in the context of RAR sampling scheme"
|
|
125
128
|
)
|
|
126
129
|
|
|
127
130
|
if rar_parameters is not None:
|
|
128
131
|
# Default p is None. However, in the RAR sampling scheme we use 0
|
|
129
132
|
# probability to specify non-used collocation points (i.e. points
|
|
130
|
-
# above
|
|
133
|
+
# above n_start). Thus, p is a vector of probability of shape (nt, 1).
|
|
131
134
|
p = jnp.zeros((n,))
|
|
132
135
|
p = p.at[:n_start].set(1 / n_start)
|
|
133
136
|
# set internal counter for the number of gradient steps since the
|
|
@@ -163,81 +166,83 @@ class DataGeneratorODE(eqx.Module):
|
|
|
163
166
|
The minimum value of the time domain to consider
|
|
164
167
|
tmax : float
|
|
165
168
|
The maximum value of the time domain to consider
|
|
166
|
-
temporal_batch_size : int
|
|
169
|
+
temporal_batch_size : int | None, default=None
|
|
167
170
|
The size of the batch of randomly selected points among
|
|
168
|
-
the `nt` points.
|
|
171
|
+
the `nt` points. If None, no minibatches are used.
|
|
169
172
|
method : str, default="uniform"
|
|
170
173
|
Either `grid` or `uniform`, default is `uniform`.
|
|
171
174
|
The method that generates the `nt` time points. `grid` means
|
|
172
175
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
173
176
|
sampled points over the domain
|
|
174
177
|
rar_parameters : Dict[str, Int], default=None
|
|
175
|
-
|
|
176
|
-
Otherwise a dictionary with keys
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
178
|
+
Defaults to None: do not use Residual Adaptative Resampling.
|
|
179
|
+
Otherwise a dictionary with keys
|
|
180
|
+
|
|
181
|
+
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
182
|
+
- `update_every`: the number of gradient steps taken between
|
|
183
|
+
each update of collocation points in the RAR algo.
|
|
184
|
+
- `sample_size`: the size of the sample from which we will select new
|
|
185
|
+
collocation points.
|
|
186
|
+
- `selected_sample_size`: the number of selected
|
|
182
187
|
points from the sample to be added to the current collocation
|
|
183
|
-
points
|
|
184
|
-
|
|
188
|
+
points.
|
|
189
|
+
n_start : Int, default=None
|
|
185
190
|
Defaults to None. The effective size of nt used at start time.
|
|
186
191
|
This value must be
|
|
187
192
|
provided when rar_parameters is not None. Otherwise we set internally
|
|
188
|
-
|
|
189
|
-
In RAR,
|
|
193
|
+
n_start = nt and this is hidden from the user.
|
|
194
|
+
In RAR, n_start
|
|
190
195
|
then corresponds to the initial number of points we train the PINN.
|
|
191
196
|
"""
|
|
192
197
|
|
|
193
|
-
key: Key
|
|
194
|
-
nt: Int
|
|
195
|
-
tmin: Float
|
|
196
|
-
tmax: Float
|
|
197
|
-
temporal_batch_size: Int = eqx.field(static=True
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
198
|
+
key: Key = eqx.field(kw_only=True)
|
|
199
|
+
nt: Int = eqx.field(kw_only=True, static=True)
|
|
200
|
+
tmin: Float = eqx.field(kw_only=True)
|
|
201
|
+
tmax: Float = eqx.field(kw_only=True)
|
|
202
|
+
temporal_batch_size: Int | None = eqx.field(static=True, default=None, kw_only=True)
|
|
203
|
+
method: str = eqx.field(
|
|
204
|
+
static=True, kw_only=True, default_factory=lambda: "uniform"
|
|
205
|
+
)
|
|
206
|
+
rar_parameters: Dict[str, Int] = eqx.field(default=None, kw_only=True)
|
|
207
|
+
n_start: Int = eqx.field(static=True, default=None, kw_only=True)
|
|
202
208
|
|
|
203
|
-
# all the init=False fields are set in __post_init__
|
|
204
|
-
|
|
205
|
-
p_times: Float[Array, "nt"] = eqx.field(init=False)
|
|
209
|
+
# all the init=False fields are set in __post_init__
|
|
210
|
+
p: Float[Array, "nt 1"] = eqx.field(init=False)
|
|
206
211
|
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
207
212
|
rar_iter_nb: Int = eqx.field(init=False)
|
|
208
213
|
curr_time_idx: Int = eqx.field(init=False)
|
|
209
|
-
times: Float[Array, "nt"] = eqx.field(init=False)
|
|
214
|
+
times: Float[Array, "nt 1"] = eqx.field(init=False)
|
|
210
215
|
|
|
211
216
|
def __post_init__(self):
|
|
212
217
|
(
|
|
213
|
-
self.
|
|
214
|
-
self.
|
|
218
|
+
self.n_start,
|
|
219
|
+
self.p,
|
|
215
220
|
self.rar_iter_from_last_sampling,
|
|
216
221
|
self.rar_iter_nb,
|
|
217
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
222
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.nt, self.n_start)
|
|
223
|
+
|
|
224
|
+
if self.temporal_batch_size is not None:
|
|
225
|
+
self.curr_time_idx = self.nt + self.temporal_batch_size
|
|
226
|
+
# to be sure there is a shuffling at first get_batch()
|
|
227
|
+
# NOTE in the extreme case we could do:
|
|
228
|
+
# self.curr_time_idx=jnp.iinfo(jnp.int32).max - self.temporal_batch_size - 1
|
|
229
|
+
# but we do not test for such extreme values. Where we subtract
|
|
230
|
+
# self.temporal_batch_size - 1 because otherwise when computing
|
|
231
|
+
# `bend` we do not want to overflow the max int32 with unwanted behaviour
|
|
232
|
+
else:
|
|
233
|
+
self.curr_time_idx = 0
|
|
227
234
|
|
|
228
235
|
self.key, self.times = self.generate_time_data(self.key)
|
|
229
236
|
# Note that, here, in __init__ (and __post_init__), this is the
|
|
230
237
|
# only place where self assignment are authorized so we do the
|
|
231
|
-
# above way for the key.
|
|
232
|
-
# key from generate_*_data is to easily align key with legacy
|
|
233
|
-
# DataGenerators to use same unit tests
|
|
238
|
+
# above way for the key.
|
|
234
239
|
|
|
235
240
|
def sample_in_time_domain(
|
|
236
241
|
self, key: Key, sample_size: Int = None
|
|
237
|
-
) -> Float[Array, "nt"]:
|
|
242
|
+
) -> Float[Array, "nt 1"]:
|
|
238
243
|
return jax.random.uniform(
|
|
239
244
|
key,
|
|
240
|
-
(self.nt if sample_size is None else sample_size,),
|
|
245
|
+
(self.nt if sample_size is None else sample_size, 1),
|
|
241
246
|
minval=self.tmin,
|
|
242
247
|
maxval=self.tmax,
|
|
243
248
|
)
|
|
@@ -247,26 +252,26 @@ class DataGeneratorODE(eqx.Module):
|
|
|
247
252
|
Construct a complete set of `self.nt` time points according to the
|
|
248
253
|
specified `self.method`
|
|
249
254
|
|
|
250
|
-
Note that self.times has always size self.nt and not self.
|
|
255
|
+
Note that self.times has always size self.nt and not self.n_start, even
|
|
251
256
|
in RAR scheme, we must allocate all the collocation points
|
|
252
257
|
"""
|
|
253
258
|
key, subkey = jax.random.split(self.key)
|
|
254
259
|
if self.method == "grid":
|
|
255
260
|
partial_times = (self.tmax - self.tmin) / self.nt
|
|
256
|
-
return key, jnp.arange(self.tmin, self.tmax, partial_times)
|
|
261
|
+
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
257
262
|
if self.method == "uniform":
|
|
258
263
|
return key, self.sample_in_time_domain(subkey)
|
|
259
264
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
260
265
|
|
|
261
266
|
def _get_time_operands(
|
|
262
267
|
self,
|
|
263
|
-
) -> tuple[Key, Float[Array, "nt"], Int, Int, Float[Array, "nt"]]:
|
|
268
|
+
) -> tuple[Key, Float[Array, "nt 1"], Int, Int, Float[Array, "nt 1"]]:
|
|
264
269
|
return (
|
|
265
270
|
self.key,
|
|
266
271
|
self.times,
|
|
267
272
|
self.curr_time_idx,
|
|
268
273
|
self.temporal_batch_size,
|
|
269
|
-
self.
|
|
274
|
+
self.p,
|
|
270
275
|
)
|
|
271
276
|
|
|
272
277
|
def temporal_batch(
|
|
@@ -276,14 +281,18 @@ class DataGeneratorODE(eqx.Module):
|
|
|
276
281
|
Return a batch of time points. If all the batches have been seen, we
|
|
277
282
|
reshuffle them, otherwise we just return the next unseen batch.
|
|
278
283
|
"""
|
|
284
|
+
if self.temporal_batch_size is None or self.temporal_batch_size == self.nt:
|
|
285
|
+
# Avoid unnecessary reshuffling
|
|
286
|
+
return self, self.times
|
|
287
|
+
|
|
279
288
|
bstart = self.curr_time_idx
|
|
280
289
|
bend = bstart + self.temporal_batch_size
|
|
281
290
|
|
|
282
291
|
# Compute the effective number of used collocation points
|
|
283
292
|
if self.rar_parameters is not None:
|
|
284
293
|
nt_eff = (
|
|
285
|
-
self.
|
|
286
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
294
|
+
self.n_start
|
|
295
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
287
296
|
)
|
|
288
297
|
else:
|
|
289
298
|
nt_eff = self.nt
|
|
@@ -295,11 +304,11 @@ class DataGeneratorODE(eqx.Module):
|
|
|
295
304
|
|
|
296
305
|
# commands below are equivalent to
|
|
297
306
|
# return self.times[i:(i+t_batch_size)]
|
|
298
|
-
# start indices can be dynamic
|
|
307
|
+
# start indices can be dynamic but the slice shape is fixed
|
|
299
308
|
return new, jax.lax.dynamic_slice(
|
|
300
309
|
new.times,
|
|
301
|
-
start_indices=(new.curr_time_idx,),
|
|
302
|
-
slice_sizes=(new.temporal_batch_size,),
|
|
310
|
+
start_indices=(new.curr_time_idx, 0),
|
|
311
|
+
slice_sizes=(new.temporal_batch_size, 1),
|
|
303
312
|
)
|
|
304
313
|
|
|
305
314
|
def get_batch(self) -> tuple["DataGeneratorODE", ODEBatch]:
|
|
@@ -324,17 +333,17 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
324
333
|
batches. Batches are made so that each data point is seen only
|
|
325
334
|
once during 1 epoch.
|
|
326
335
|
nb : Int | None
|
|
327
|
-
The total number of points in $\partial\Omega$.
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
omega_batch_size : Int
|
|
336
|
+
The total number of points in $\partial\Omega$. Can be None if no
|
|
337
|
+
boundary condition is specified.
|
|
338
|
+
omega_batch_size : Int | None, default=None
|
|
331
339
|
The size of the batch of randomly selected points among
|
|
332
|
-
the `n` points.
|
|
333
|
-
omega_border_batch_size : Int | None
|
|
340
|
+
the `n` points. If None no minibatches are used.
|
|
341
|
+
omega_border_batch_size : Int | None, default=None
|
|
334
342
|
The size of the batch of points randomly selected
|
|
335
|
-
among the `nb` points.
|
|
336
|
-
|
|
337
|
-
|
|
343
|
+
among the `nb` points. If None, `omega_border_batch_size`
|
|
344
|
+
no minibatches are used. In dimension 1,
|
|
345
|
+
minibatches are never used since the boundary is composed of two
|
|
346
|
+
singletons.
|
|
338
347
|
dim : Int
|
|
339
348
|
Dimension of $\Omega$ domain
|
|
340
349
|
min_pts : tuple[tuple[Float, Float], ...]
|
|
@@ -351,34 +360,39 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
351
360
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
352
361
|
sampled points over the domain
|
|
353
362
|
rar_parameters : Dict[str, Int], default=None
|
|
354
|
-
|
|
355
|
-
Otherwise a dictionary with keys
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
363
|
+
Defaults to None: do not use Residual Adaptative Resampling.
|
|
364
|
+
Otherwise a dictionary with keys
|
|
365
|
+
|
|
366
|
+
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
367
|
+
- `update_every`: the number of gradient steps taken between
|
|
368
|
+
each update of collocation points in the RAR algo.
|
|
369
|
+
- `sample_size`: the size of the sample from which we will select new
|
|
370
|
+
collocation points.
|
|
371
|
+
- `selected_sample_size`: the number of selected
|
|
361
372
|
points from the sample to be added to the current collocation
|
|
362
|
-
points
|
|
373
|
+
points.
|
|
363
374
|
n_start : Int, default=None
|
|
364
375
|
Defaults to None. The effective size of n used at start time.
|
|
365
376
|
This value must be
|
|
366
377
|
provided when rar_parameters is not None. Otherwise we set internally
|
|
367
378
|
n_start = n and this is hidden from the user.
|
|
368
379
|
In RAR, n_start
|
|
369
|
-
then corresponds to the initial number of points we train the PINN.
|
|
380
|
+
then corresponds to the initial number of points we train the PINN on.
|
|
370
381
|
"""
|
|
371
382
|
|
|
372
383
|
# kw_only in base class is motivated here: https://stackoverflow.com/a/69822584
|
|
373
384
|
key: Key = eqx.field(kw_only=True)
|
|
374
|
-
n: Int = eqx.field(kw_only=True)
|
|
375
|
-
nb: Int | None = eqx.field(kw_only=True)
|
|
376
|
-
omega_batch_size: Int = eqx.field(
|
|
377
|
-
kw_only=True,
|
|
385
|
+
n: Int = eqx.field(kw_only=True, static=True)
|
|
386
|
+
nb: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
387
|
+
omega_batch_size: Int | None = eqx.field(
|
|
388
|
+
kw_only=True,
|
|
389
|
+
static=True,
|
|
390
|
+
default=None, # can be None as
|
|
391
|
+
# CubicMeshPDENonStatio inherits but also if omega_batch_size=n
|
|
378
392
|
) # static cause used as a
|
|
379
393
|
# shape in jax.lax.dynamic_slice
|
|
380
394
|
omega_border_batch_size: Int | None = eqx.field(
|
|
381
|
-
kw_only=True, static=True
|
|
395
|
+
kw_only=True, static=True, default=None
|
|
382
396
|
) # static cause used as a
|
|
383
397
|
# shape in jax.lax.dynamic_slice
|
|
384
398
|
dim: Int = eqx.field(kw_only=True, static=True) # static cause used as a
|
|
@@ -391,10 +405,8 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
391
405
|
rar_parameters: Dict[str, Int] = eqx.field(kw_only=True, default=None)
|
|
392
406
|
n_start: Int = eqx.field(kw_only=True, default=None, static=True)
|
|
393
407
|
|
|
394
|
-
# all the init=False fields are set in __post_init__
|
|
395
|
-
|
|
396
|
-
p_omega: Float[Array, "n"] = eqx.field(init=False)
|
|
397
|
-
p_border: None = eqx.field(init=False)
|
|
408
|
+
# all the init=False fields are set in __post_init__
|
|
409
|
+
p: Float[Array, "n"] = eqx.field(init=False)
|
|
398
410
|
rar_iter_from_last_sampling: Int = eqx.field(init=False)
|
|
399
411
|
rar_iter_nb: Int = eqx.field(init=False)
|
|
400
412
|
curr_omega_idx: Int = eqx.field(init=False)
|
|
@@ -410,51 +422,59 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
410
422
|
|
|
411
423
|
(
|
|
412
424
|
self.n_start,
|
|
413
|
-
self.
|
|
425
|
+
self.p,
|
|
414
426
|
self.rar_iter_from_last_sampling,
|
|
415
427
|
self.rar_iter_nb,
|
|
416
428
|
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
417
429
|
|
|
418
|
-
self.
|
|
430
|
+
if self.method == "grid" and self.dim == 2:
|
|
431
|
+
perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
432
|
+
if self.n != perfect_sq:
|
|
433
|
+
warnings.warn(
|
|
434
|
+
"Grid sampling is requested in dimension 2 with a non"
|
|
435
|
+
f" perfect square dataset size (self.n = {self.n})."
|
|
436
|
+
f" Modifying self.n to self.n = {perfect_sq}."
|
|
437
|
+
)
|
|
438
|
+
self.n = perfect_sq
|
|
419
439
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
440
|
+
if self.nb is not None:
|
|
441
|
+
if self.dim == 1:
|
|
442
|
+
self.omega_border_batch_size = None
|
|
443
|
+
# We are in 1-D case => omega_border_batch_size is
|
|
444
|
+
# ignored since borders of Omega are singletons.
|
|
445
|
+
# self.border_batch() will return [xmin, xmax]
|
|
446
|
+
else:
|
|
447
|
+
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
448
|
+
raise ValueError(
|
|
449
|
+
f"number of border point must be"
|
|
450
|
+
f" a multiple of 2xd = {2*self.dim} (the # of faces of"
|
|
451
|
+
f" a d-dimensional cube). Got {self.nb=}."
|
|
452
|
+
)
|
|
453
|
+
if (
|
|
454
|
+
self.omega_border_batch_size is not None
|
|
455
|
+
and self.nb // (2 * self.dim) < self.omega_border_batch_size
|
|
456
|
+
):
|
|
457
|
+
raise ValueError(
|
|
458
|
+
f"number of points per facets ({self.nb//(2*self.dim)})"
|
|
459
|
+
f" cannot be lower than border batch size "
|
|
460
|
+
f" ({self.omega_border_batch_size})."
|
|
461
|
+
)
|
|
462
|
+
self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
|
|
463
|
+
|
|
464
|
+
if self.omega_batch_size is None:
|
|
465
|
+
self.curr_omega_idx = 0
|
|
433
466
|
else:
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
"number of border point must be"
|
|
437
|
-
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
438
|
-
)
|
|
439
|
-
if self.nb // (2 * self.dim) < self.omega_border_batch_size:
|
|
440
|
-
raise ValueError(
|
|
441
|
-
"number of points per facets (nb//2*self.dim)"
|
|
442
|
-
" cannot be lower than border batch size"
|
|
443
|
-
)
|
|
444
|
-
self.nb = int((2 * self.dim) * (self.nb // (2 * self.dim)))
|
|
467
|
+
self.curr_omega_idx = self.n + self.omega_batch_size
|
|
468
|
+
# to be sure there is a shuffling at first get_batch()
|
|
445
469
|
|
|
446
|
-
self.curr_omega_idx = jnp.iinfo(jnp.int32).max - self.omega_batch_size - 1
|
|
447
|
-
# see explaination in DataGeneratorODE
|
|
448
470
|
if self.omega_border_batch_size is None:
|
|
449
|
-
self.curr_omega_border_idx =
|
|
471
|
+
self.curr_omega_border_idx = 0
|
|
450
472
|
else:
|
|
451
|
-
self.curr_omega_border_idx =
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
self.key, self.omega, self.omega_border = self.generate_data(self.key)
|
|
457
|
-
# see explaination in DataGeneratorODE for the key
|
|
473
|
+
self.curr_omega_border_idx = self.nb + self.omega_border_batch_size
|
|
474
|
+
# to be sure there is a shuffling at first get_batch()
|
|
475
|
+
|
|
476
|
+
self.key, self.omega = self.generate_omega_data(self.key)
|
|
477
|
+
self.key, self.omega_border = self.generate_omega_border_data(self.key)
|
|
458
478
|
|
|
459
479
|
def sample_in_omega_domain(
|
|
460
480
|
self, keys: Key, sample_size: Int = None
|
|
@@ -480,9 +500,10 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
480
500
|
)
|
|
481
501
|
|
|
482
502
|
def sample_in_omega_border_domain(
|
|
483
|
-
self, keys: Key
|
|
503
|
+
self, keys: Key, sample_size: int = None
|
|
484
504
|
) -> Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None:
|
|
485
|
-
|
|
505
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
506
|
+
if sample_size is None:
|
|
486
507
|
return None
|
|
487
508
|
if self.dim == 1:
|
|
488
509
|
xmin = self.min_pts[0]
|
|
@@ -492,8 +513,7 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
492
513
|
# currently hard-coded the 4 edges for d==2
|
|
493
514
|
# TODO : find a general & efficient way to sample from the border
|
|
494
515
|
# (facets) of the hypercube in general dim.
|
|
495
|
-
|
|
496
|
-
facet_n = self.nb // (2 * self.dim)
|
|
516
|
+
facet_n = sample_size // (2 * self.dim)
|
|
497
517
|
xmin = jnp.hstack(
|
|
498
518
|
[
|
|
499
519
|
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
@@ -544,54 +564,64 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
544
564
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
545
565
|
)
|
|
546
566
|
|
|
547
|
-
def
|
|
567
|
+
def generate_omega_data(self, key: Key, data_size: int = None) -> tuple[
|
|
548
568
|
Key,
|
|
549
569
|
Float[Array, "n dim"],
|
|
550
|
-
Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
|
|
551
570
|
]:
|
|
552
571
|
r"""
|
|
553
572
|
Construct a complete set of `self.n` $\Omega$ points according to the
|
|
554
|
-
specified `self.method`.
|
|
555
|
-
$\partial\Omega$ points if `self.omega_border_batch_size` is not
|
|
556
|
-
`None`. If the latter is `None` we set `self.omega_border` to `None`.
|
|
573
|
+
specified `self.method`.
|
|
557
574
|
"""
|
|
575
|
+
data_size = self.n if data_size is None else data_size
|
|
558
576
|
# Generate Omega
|
|
559
577
|
if self.method == "grid":
|
|
560
578
|
if self.dim == 1:
|
|
561
579
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
omega = jnp.arange(xmin, xmax, partial)[:, None]
|
|
580
|
+
## shape (n, 1)
|
|
581
|
+
omega = jnp.linspace(xmin, xmax, data_size)[:, None]
|
|
565
582
|
else:
|
|
566
|
-
partials = [
|
|
567
|
-
(self.max_pts[i] - self.min_pts[i]) / jnp.sqrt(self.n)
|
|
568
|
-
for i in range(self.dim)
|
|
569
|
-
]
|
|
570
583
|
xyz_ = jnp.meshgrid(
|
|
571
584
|
*[
|
|
572
|
-
jnp.
|
|
585
|
+
jnp.linspace(
|
|
586
|
+
self.min_pts[i],
|
|
587
|
+
self.max_pts[i],
|
|
588
|
+
int(jnp.round(jnp.sqrt(data_size))),
|
|
589
|
+
)
|
|
573
590
|
for i in range(self.dim)
|
|
574
591
|
]
|
|
575
592
|
)
|
|
576
|
-
xyz_ = [a.reshape((
|
|
593
|
+
xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
|
|
577
594
|
omega = jnp.concatenate(xyz_, axis=-1)
|
|
578
595
|
elif self.method == "uniform":
|
|
579
596
|
if self.dim == 1:
|
|
580
597
|
key, subkeys = jax.random.split(key, 2)
|
|
581
598
|
else:
|
|
582
599
|
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
583
|
-
omega = self.sample_in_omega_domain(subkeys)
|
|
600
|
+
omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
|
|
584
601
|
else:
|
|
585
602
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
603
|
+
return key, omega
|
|
586
604
|
|
|
605
|
+
def generate_omega_border_data(self, key: Key, data_size: int = None) -> tuple[
|
|
606
|
+
Key,
|
|
607
|
+
Float[Array, "1 2"] | Float[Array, "(nb//4) 2 4"] | None,
|
|
608
|
+
]:
|
|
609
|
+
r"""
|
|
610
|
+
Also constructs a complete set of `self.nb`
|
|
611
|
+
$\partial\Omega$ points if `self.omega_border_batch_size` is not
|
|
612
|
+
`None`. If the latter is `None` we set `self.omega_border` to `None`.
|
|
613
|
+
"""
|
|
587
614
|
# Generate border of omega
|
|
588
|
-
|
|
615
|
+
data_size = self.nb if data_size is None else data_size
|
|
616
|
+
if self.dim == 2:
|
|
589
617
|
key, *subkeys = jax.random.split(key, 5)
|
|
590
618
|
else:
|
|
591
619
|
subkeys = None
|
|
592
|
-
omega_border = self.sample_in_omega_border_domain(
|
|
620
|
+
omega_border = self.sample_in_omega_border_domain(
|
|
621
|
+
subkeys, sample_size=data_size
|
|
622
|
+
)
|
|
593
623
|
|
|
594
|
-
return key,
|
|
624
|
+
return key, omega_border
|
|
595
625
|
|
|
596
626
|
def _get_omega_operands(
|
|
597
627
|
self,
|
|
@@ -601,7 +631,7 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
601
631
|
self.omega,
|
|
602
632
|
self.curr_omega_idx,
|
|
603
633
|
self.omega_batch_size,
|
|
604
|
-
self.
|
|
634
|
+
self.p,
|
|
605
635
|
)
|
|
606
636
|
|
|
607
637
|
def inside_batch(
|
|
@@ -612,11 +642,15 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
612
642
|
If all the batches have been seen, we reshuffle them,
|
|
613
643
|
otherwise we just return the next unseen batch.
|
|
614
644
|
"""
|
|
645
|
+
if self.omega_batch_size is None or self.omega_batch_size == self.n:
|
|
646
|
+
# Avoid unnecessary reshuffling
|
|
647
|
+
return self, self.omega
|
|
648
|
+
|
|
615
649
|
# Compute the effective number of used collocation points
|
|
616
650
|
if self.rar_parameters is not None:
|
|
617
651
|
n_eff = (
|
|
618
652
|
self.n_start
|
|
619
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
653
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
620
654
|
)
|
|
621
655
|
else:
|
|
622
656
|
n_eff = self.n
|
|
@@ -645,7 +679,7 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
645
679
|
self.omega_border,
|
|
646
680
|
self.curr_omega_border_idx,
|
|
647
681
|
self.omega_border_batch_size,
|
|
648
|
-
|
|
682
|
+
None,
|
|
649
683
|
)
|
|
650
684
|
|
|
651
685
|
def border_batch(
|
|
@@ -670,12 +704,23 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
670
704
|
|
|
671
705
|
|
|
672
706
|
"""
|
|
673
|
-
if self.
|
|
707
|
+
if self.nb is None:
|
|
708
|
+
# Avoid unnecessary reshuffling
|
|
674
709
|
return self, None
|
|
710
|
+
|
|
675
711
|
if self.dim == 1:
|
|
712
|
+
# Avoid unnecessary reshuffling
|
|
676
713
|
# 1-D case, no randomness : we always return the whole omega border,
|
|
677
714
|
# i.e. (1, 1, 2) shape jnp.array([[[xmin], [xmax]]]).
|
|
678
715
|
return self, self.omega_border[None, None] # shape is (1, 1, 2)
|
|
716
|
+
|
|
717
|
+
if (
|
|
718
|
+
self.omega_border_batch_size is None
|
|
719
|
+
or self.omega_border_batch_size == self.nb // 2**self.dim
|
|
720
|
+
):
|
|
721
|
+
# Avoid unnecessary reshuffling
|
|
722
|
+
return self, self.omega_border
|
|
723
|
+
|
|
679
724
|
bstart = self.curr_omega_border_idx
|
|
680
725
|
bend = bstart + self.omega_border_batch_size
|
|
681
726
|
|
|
@@ -701,7 +746,7 @@ class CubicMeshPDEStatio(eqx.Module):
|
|
|
701
746
|
"""
|
|
702
747
|
new, inside_batch = self.inside_batch()
|
|
703
748
|
new, border_batch = new.border_batch()
|
|
704
|
-
return new, PDEStatioBatch(
|
|
749
|
+
return new, PDEStatioBatch(domain_batch=inside_batch, border_batch=border_batch)
|
|
705
750
|
|
|
706
751
|
|
|
707
752
|
class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
@@ -715,30 +760,29 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
715
760
|
key : Key
|
|
716
761
|
Jax random key to sample new time points and to shuffle batches
|
|
717
762
|
n : Int
|
|
718
|
-
The number of total
|
|
763
|
+
The number of total $I\times \Omega$ points that will be divided in
|
|
719
764
|
batches. Batches are made so that each data point is seen only
|
|
720
765
|
once during 1 epoch.
|
|
721
766
|
nb : Int | None
|
|
722
|
-
The total number of points in $\partial\Omega$.
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
The number of total time points that will be divided in
|
|
767
|
+
The total number of points in $\partial\Omega$. Can be None if no
|
|
768
|
+
boundary condition is specified.
|
|
769
|
+
ni : Int
|
|
770
|
+
The number of total $\Omega$ points at $t=0$ that will be divided in
|
|
727
771
|
batches. Batches are made so that each data point is seen only
|
|
728
772
|
once during 1 epoch.
|
|
729
|
-
|
|
773
|
+
domain_batch_size : Int | None, default=None
|
|
730
774
|
The size of the batch of randomly selected points among
|
|
731
|
-
the `n` points.
|
|
732
|
-
|
|
775
|
+
the `n` points. If None no mini-batches are used.
|
|
776
|
+
border_batch_size : Int | None, default=None
|
|
733
777
|
The size of the batch of points randomly selected
|
|
734
|
-
among the `nb` points.
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
temporal_batch_size : Int
|
|
778
|
+
among the `nb` points. If None, `domain_batch_size` no
|
|
779
|
+
mini-batches are used.
|
|
780
|
+
initial_batch_size : Int | None, default=None
|
|
738
781
|
The size of the batch of randomly selected points among
|
|
739
|
-
the `
|
|
782
|
+
the `ni` points. If None no
|
|
783
|
+
mini-batches are used.
|
|
740
784
|
dim : Int
|
|
741
|
-
An integer.
|
|
785
|
+
An integer. Dimension of $\Omega$ domain.
|
|
742
786
|
min_pts : tuple[tuple[Float, Float], ...]
|
|
743
787
|
A tuple of minimum values of the domain along each dimension. For a sampling
|
|
744
788
|
in `n` dimension, this represents $(x_{1, min}, x_{2,min}, ...,
|
|
@@ -757,13 +801,15 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
757
801
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
758
802
|
sampled points over the domain
|
|
759
803
|
rar_parameters : Dict[str, Int], default=None
|
|
760
|
-
|
|
761
|
-
Otherwise a dictionary with keys
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
804
|
+
Defaults to None: do not use Residual Adaptative Resampling.
|
|
805
|
+
Otherwise a dictionary with keys
|
|
806
|
+
|
|
807
|
+
- `start_iter`: the iteration at which we start the RAR sampling scheme (we first have a "burn-in" period).
|
|
808
|
+
- `update_every`: the number of gradient steps taken between
|
|
809
|
+
each update of collocation points in the RAR algo.
|
|
810
|
+
- `sample_size`: the size of the sample from which we will select new
|
|
811
|
+
collocation points.
|
|
812
|
+
- `selected_sample_size`: the number of selected
|
|
767
813
|
points from the sample to be added to the current collocation
|
|
768
814
|
points.
|
|
769
815
|
n_start : Int, default=None
|
|
@@ -773,27 +819,23 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
773
819
|
n_start = n and this is hidden from the user.
|
|
774
820
|
In RAR, n_start
|
|
775
821
|
then corresponds to the initial number of omega points we train the PINN.
|
|
776
|
-
nt_start : Int, default=None
|
|
777
|
-
Defaults to None. A RAR hyper-parameter. Same as ``n_start`` but
|
|
778
|
-
for times collocation point. See also ``DataGeneratorODE``
|
|
779
|
-
documentation.
|
|
780
|
-
cartesian_product : Bool, default=True
|
|
781
|
-
Defaults to True. Whether we return the cartesian product of the
|
|
782
|
-
temporal batch with the inside and border batches. If False we just
|
|
783
|
-
return their concatenation.
|
|
784
822
|
"""
|
|
785
823
|
|
|
786
|
-
temporal_batch_size: Int = eqx.field(kw_only=True)
|
|
787
824
|
tmin: Float = eqx.field(kw_only=True)
|
|
788
825
|
tmax: Float = eqx.field(kw_only=True)
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
826
|
+
ni: Int = eqx.field(kw_only=True, static=True)
|
|
827
|
+
domain_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
828
|
+
initial_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
829
|
+
border_batch_size: Int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
830
|
+
|
|
831
|
+
curr_domain_idx: Int = eqx.field(init=False)
|
|
832
|
+
curr_initial_idx: Int = eqx.field(init=False)
|
|
833
|
+
curr_border_idx: Int = eqx.field(init=False)
|
|
834
|
+
domain: Float[Array, "n 1+dim"] = eqx.field(init=False)
|
|
835
|
+
border: Float[Array, "(nb//2) 1+1 2"] | Float[Array, "(nb//4) 2+1 4"] | None = (
|
|
836
|
+
eqx.field(init=False)
|
|
837
|
+
)
|
|
838
|
+
initial: Float[Array, "ni dim"] = eqx.field(init=False)
|
|
797
839
|
|
|
798
840
|
def __post_init__(self):
|
|
799
841
|
"""
|
|
@@ -803,162 +845,310 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
803
845
|
super().__post_init__() # because __init__ or __post_init__ of Base
|
|
804
846
|
# class is not automatically called
|
|
805
847
|
|
|
806
|
-
if
|
|
807
|
-
|
|
848
|
+
if self.method == "grid":
|
|
849
|
+
# NOTE we must redo the sampling with the square root number of samples
|
|
850
|
+
# and then take the cartesian product
|
|
851
|
+
self.n = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
852
|
+
if self.dim == 2:
|
|
853
|
+
# in the case of grid sampling in 2D in dim 2 in non-statio,
|
|
854
|
+
# self.n needs to be a perfect ^4, because there is the
|
|
855
|
+
# cartesian product with time domain which is also present
|
|
856
|
+
perfect_4 = int(jnp.round(self.n**0.25) ** 4)
|
|
857
|
+
if self.n != perfect_4:
|
|
858
|
+
warnings.warn(
|
|
859
|
+
"Grid sampling is requested in dimension 2 in non"
|
|
860
|
+
" stationary setting with a non"
|
|
861
|
+
f" perfect square dataset size (self.n = {self.n})."
|
|
862
|
+
f" Modifying self.n to self.n = {perfect_4}."
|
|
863
|
+
)
|
|
864
|
+
self.n = perfect_4
|
|
865
|
+
self.key, half_domain_times = self.generate_time_data(
|
|
866
|
+
self.key, int(jnp.round(jnp.sqrt(self.n)))
|
|
867
|
+
)
|
|
868
|
+
|
|
869
|
+
self.key, half_domain_omega = self.generate_omega_data(
|
|
870
|
+
self.key, data_size=int(jnp.round(jnp.sqrt(self.n)))
|
|
871
|
+
)
|
|
872
|
+
self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
|
|
873
|
+
|
|
874
|
+
# NOTE
|
|
875
|
+
(
|
|
876
|
+
self.n_start,
|
|
877
|
+
self.p,
|
|
878
|
+
self.rar_iter_from_last_sampling,
|
|
879
|
+
self.rar_iter_nb,
|
|
880
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, self.n_start)
|
|
881
|
+
elif self.method == "uniform":
|
|
882
|
+
self.key, domain_times = self.generate_time_data(self.key, self.n)
|
|
883
|
+
self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
|
|
884
|
+
else:
|
|
885
|
+
raise ValueError(
|
|
886
|
+
f"Bad value for method. Got {self.method}, expected"
|
|
887
|
+
' "grid" or "uniform"'
|
|
888
|
+
)
|
|
889
|
+
|
|
890
|
+
if self.domain_batch_size is None:
|
|
891
|
+
self.curr_domain_idx = 0
|
|
892
|
+
else:
|
|
893
|
+
self.curr_domain_idx = self.n + self.domain_batch_size
|
|
894
|
+
# to be sure there is a shuffling at first get_batch()
|
|
895
|
+
if self.nb is not None:
|
|
896
|
+
# the check below has already been done in super.__post_init__ if
|
|
897
|
+
# dim > 1. Here we retest it in whatever dim
|
|
898
|
+
if self.nb % (2 * self.dim) != 0 or self.nb < 2 * self.dim:
|
|
808
899
|
raise ValueError(
|
|
809
|
-
"
|
|
810
|
-
"
|
|
811
|
-
"must then be equal to self.omega_batch_size"
|
|
900
|
+
"number of border point must be"
|
|
901
|
+
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
812
902
|
)
|
|
903
|
+
# the check below concern omega_border_batch_size for dim > 1 in
|
|
904
|
+
# super.__post_init__. Here it concerns all dim values since our
|
|
905
|
+
# border_batch is the concatenation or cartesian product with times
|
|
813
906
|
if (
|
|
814
|
-
self.
|
|
815
|
-
and self.
|
|
816
|
-
and self.temporal_batch_size != self.omega_border_batch_size
|
|
907
|
+
self.border_batch_size is not None
|
|
908
|
+
and self.nb // (2 * self.dim) < self.border_batch_size
|
|
817
909
|
):
|
|
818
910
|
raise ValueError(
|
|
819
|
-
"
|
|
820
|
-
"
|
|
821
|
-
|
|
911
|
+
"number of points per facets (nb//2*self.dim)"
|
|
912
|
+
" cannot be lower than border batch size"
|
|
913
|
+
)
|
|
914
|
+
self.key, boundary_times = self.generate_time_data(
|
|
915
|
+
self.key, self.nb // (2 * self.dim)
|
|
916
|
+
)
|
|
917
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
918
|
+
boundary_times = jnp.repeat(
|
|
919
|
+
boundary_times, self.omega_border.shape[-1], axis=2
|
|
920
|
+
)
|
|
921
|
+
if self.dim == 1:
|
|
922
|
+
self.border = make_cartesian_product(
|
|
923
|
+
boundary_times, self.omega_border[None, None]
|
|
822
924
|
)
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
925
|
+
else:
|
|
926
|
+
self.border = jnp.concatenate(
|
|
927
|
+
[boundary_times, self.omega_border], axis=1
|
|
928
|
+
)
|
|
929
|
+
if self.border_batch_size is None:
|
|
930
|
+
self.curr_border_idx = 0
|
|
931
|
+
else:
|
|
932
|
+
self.curr_border_idx = self.nb + self.border_batch_size
|
|
933
|
+
# to be sure there is a shuffling at first get_batch()
|
|
828
934
|
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
self.
|
|
832
|
-
self.
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
935
|
+
else:
|
|
936
|
+
self.border = None
|
|
937
|
+
self.curr_border_idx = None
|
|
938
|
+
self.border_batch_size = None
|
|
939
|
+
|
|
940
|
+
if self.ni is not None:
|
|
941
|
+
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
942
|
+
if self.ni != perfect_sq:
|
|
943
|
+
warnings.warn(
|
|
944
|
+
"Grid sampling is requested in dimension 2 with a non"
|
|
945
|
+
f" perfect square dataset size (self.ni = {self.ni})."
|
|
946
|
+
f" Modifying self.ni to self.ni = {perfect_sq}."
|
|
947
|
+
)
|
|
948
|
+
self.ni = perfect_sq
|
|
949
|
+
self.key, self.initial = self.generate_omega_data(
|
|
950
|
+
self.key, data_size=self.ni
|
|
951
|
+
)
|
|
836
952
|
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
953
|
+
if self.initial_batch_size is None or self.initial_batch_size == self.ni:
|
|
954
|
+
self.curr_initial_idx = 0
|
|
955
|
+
else:
|
|
956
|
+
self.curr_initial_idx = self.ni + self.initial_batch_size
|
|
957
|
+
# to be sure there is a shuffling at first get_batch()
|
|
958
|
+
else:
|
|
959
|
+
self.initial = None
|
|
960
|
+
self.initial_batch_size = None
|
|
961
|
+
self.curr_initial_idx = None
|
|
842
962
|
|
|
843
|
-
|
|
844
|
-
self
|
|
845
|
-
|
|
963
|
+
# the following attributes will not be used anymore
|
|
964
|
+
self.omega = None
|
|
965
|
+
self.omega_border = None
|
|
966
|
+
|
|
967
|
+
def generate_time_data(self, key: Key, nt: Int) -> tuple[Key, Float[Array, "nt 1"]]:
|
|
968
|
+
"""
|
|
969
|
+
Construct a complete set of `nt` time points according to the
|
|
970
|
+
specified `self.method`
|
|
971
|
+
"""
|
|
972
|
+
key, subkey = jax.random.split(key, 2)
|
|
973
|
+
if self.method == "grid":
|
|
974
|
+
partial_times = (self.tmax - self.tmin) / nt
|
|
975
|
+
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
976
|
+
if self.method == "uniform":
|
|
977
|
+
return key, self.sample_in_time_domain(subkey, nt)
|
|
978
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
979
|
+
|
|
980
|
+
def sample_in_time_domain(self, key: Key, nt: Int) -> Float[Array, "nt 1"]:
|
|
846
981
|
return jax.random.uniform(
|
|
847
982
|
key,
|
|
848
|
-
(
|
|
983
|
+
(nt, 1),
|
|
849
984
|
minval=self.tmin,
|
|
850
985
|
maxval=self.tmax,
|
|
851
986
|
)
|
|
852
987
|
|
|
853
|
-
def
|
|
988
|
+
def _get_domain_operands(
|
|
854
989
|
self,
|
|
855
|
-
) -> tuple[Key, Float[Array, "
|
|
990
|
+
) -> tuple[Key, Float[Array, "n 1+dim"], Int, Int, None]:
|
|
856
991
|
return (
|
|
857
992
|
self.key,
|
|
858
|
-
self.
|
|
859
|
-
self.
|
|
860
|
-
self.
|
|
861
|
-
self.
|
|
993
|
+
self.domain,
|
|
994
|
+
self.curr_domain_idx,
|
|
995
|
+
self.domain_batch_size,
|
|
996
|
+
self.p,
|
|
862
997
|
)
|
|
863
998
|
|
|
864
|
-
def
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
specified `self.method`
|
|
999
|
+
def domain_batch(
|
|
1000
|
+
self,
|
|
1001
|
+
) -> tuple["CubicMeshPDEStatio", Float[Array, "domain_batch_size 1+dim"]]:
|
|
868
1002
|
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
key, subkey = jax.random.split(key, 2)
|
|
873
|
-
if self.method == "grid":
|
|
874
|
-
partial_times = (self.tmax - self.tmin) / self.nt
|
|
875
|
-
return key, jnp.arange(self.tmin, self.tmax, partial_times)
|
|
876
|
-
if self.method == "uniform":
|
|
877
|
-
return key, self.sample_in_time_domain(subkey)
|
|
878
|
-
raise ValueError("Method " + self.method + " is not implemented.")
|
|
1003
|
+
if self.domain_batch_size is None or self.domain_batch_size == self.n:
|
|
1004
|
+
# Avoid unnecessary reshuffling
|
|
1005
|
+
return self, self.domain
|
|
879
1006
|
|
|
880
|
-
|
|
881
|
-
self
|
|
882
|
-
) -> tuple["CubicMeshPDENonStatio", Float[Array, "temporal_batch_size"]]:
|
|
883
|
-
"""
|
|
884
|
-
Return a batch of time points. If all the batches have been seen, we
|
|
885
|
-
reshuffle them, otherwise we just return the next unseen batch.
|
|
886
|
-
"""
|
|
887
|
-
bstart = self.curr_time_idx
|
|
888
|
-
bend = bstart + self.temporal_batch_size
|
|
1007
|
+
bstart = self.curr_domain_idx
|
|
1008
|
+
bend = bstart + self.domain_batch_size
|
|
889
1009
|
|
|
890
1010
|
# Compute the effective number of used collocation points
|
|
891
1011
|
if self.rar_parameters is not None:
|
|
892
|
-
|
|
893
|
-
self.
|
|
894
|
-
+ self.rar_iter_nb * self.rar_parameters["
|
|
1012
|
+
n_eff = (
|
|
1013
|
+
self.n_start
|
|
1014
|
+
+ self.rar_iter_nb * self.rar_parameters["selected_sample_size"]
|
|
895
1015
|
)
|
|
896
1016
|
else:
|
|
897
|
-
|
|
1017
|
+
n_eff = self.n
|
|
898
1018
|
|
|
899
|
-
new_attributes = _reset_or_increment(bend,
|
|
1019
|
+
new_attributes = _reset_or_increment(bend, n_eff, self._get_domain_operands())
|
|
900
1020
|
new = eqx.tree_at(
|
|
901
|
-
lambda m: (m.key, m.
|
|
1021
|
+
lambda m: (m.key, m.domain, m.curr_domain_idx),
|
|
1022
|
+
self,
|
|
1023
|
+
new_attributes,
|
|
1024
|
+
)
|
|
1025
|
+
return new, jax.lax.dynamic_slice(
|
|
1026
|
+
new.domain,
|
|
1027
|
+
start_indices=(new.curr_domain_idx, 0),
|
|
1028
|
+
slice_sizes=(new.domain_batch_size, new.dim + 1),
|
|
1029
|
+
)
|
|
1030
|
+
|
|
1031
|
+
def _get_border_operands(
|
|
1032
|
+
self,
|
|
1033
|
+
) -> tuple[
|
|
1034
|
+
Key, Float[Array, "nb 1+1 2"] | Float[Array, "(nb//4) 2+1 4"], Int, Int, None
|
|
1035
|
+
]:
|
|
1036
|
+
return (
|
|
1037
|
+
self.key,
|
|
1038
|
+
self.border,
|
|
1039
|
+
self.curr_border_idx,
|
|
1040
|
+
self.border_batch_size,
|
|
1041
|
+
None,
|
|
1042
|
+
)
|
|
1043
|
+
|
|
1044
|
+
def border_batch(
|
|
1045
|
+
self,
|
|
1046
|
+
) -> tuple[
|
|
1047
|
+
"CubicMeshPDENonStatio",
|
|
1048
|
+
Float[Array, "border_batch_size 1+1 2"]
|
|
1049
|
+
| Float[Array, "border_batch_size 2+1 4"]
|
|
1050
|
+
| None,
|
|
1051
|
+
]:
|
|
1052
|
+
if self.nb is None:
|
|
1053
|
+
# Avoid unnecessary reshuffling
|
|
1054
|
+
return self, None
|
|
1055
|
+
|
|
1056
|
+
if (
|
|
1057
|
+
self.border_batch_size is None
|
|
1058
|
+
or self.border_batch_size == self.nb // 2**self.dim
|
|
1059
|
+
):
|
|
1060
|
+
# Avoid unnecessary reshuffling
|
|
1061
|
+
return self, self.border
|
|
1062
|
+
|
|
1063
|
+
bstart = self.curr_border_idx
|
|
1064
|
+
bend = bstart + self.border_batch_size
|
|
1065
|
+
|
|
1066
|
+
n_eff = self.border.shape[0]
|
|
1067
|
+
|
|
1068
|
+
new_attributes = _reset_or_increment(bend, n_eff, self._get_border_operands())
|
|
1069
|
+
new = eqx.tree_at(
|
|
1070
|
+
lambda m: (m.key, m.border, m.curr_border_idx),
|
|
1071
|
+
self,
|
|
1072
|
+
new_attributes,
|
|
902
1073
|
)
|
|
903
1074
|
|
|
904
1075
|
return new, jax.lax.dynamic_slice(
|
|
905
|
-
new.
|
|
906
|
-
start_indices=(new.
|
|
907
|
-
slice_sizes=(
|
|
1076
|
+
new.border,
|
|
1077
|
+
start_indices=(new.curr_border_idx, 0, 0),
|
|
1078
|
+
slice_sizes=(
|
|
1079
|
+
new.border_batch_size,
|
|
1080
|
+
new.dim + 1,
|
|
1081
|
+
2 * new.dim,
|
|
1082
|
+
),
|
|
1083
|
+
)
|
|
1084
|
+
|
|
1085
|
+
def _get_initial_operands(
|
|
1086
|
+
self,
|
|
1087
|
+
) -> tuple[Key, Float[Array, "ni dim"], Int, Int, None]:
|
|
1088
|
+
return (
|
|
1089
|
+
self.key,
|
|
1090
|
+
self.initial,
|
|
1091
|
+
self.curr_initial_idx,
|
|
1092
|
+
self.initial_batch_size,
|
|
1093
|
+
None,
|
|
1094
|
+
)
|
|
1095
|
+
|
|
1096
|
+
def initial_batch(
|
|
1097
|
+
self,
|
|
1098
|
+
) -> tuple["CubicMeshPDEStatio", Float[Array, "initial_batch_size dim"]]:
|
|
1099
|
+
if self.initial_batch_size is None or self.initial_batch_size == self.ni:
|
|
1100
|
+
# Avoid unnecessary reshuffling
|
|
1101
|
+
return self, self.initial
|
|
1102
|
+
|
|
1103
|
+
bstart = self.curr_initial_idx
|
|
1104
|
+
bend = bstart + self.initial_batch_size
|
|
1105
|
+
|
|
1106
|
+
n_eff = self.ni
|
|
1107
|
+
|
|
1108
|
+
new_attributes = _reset_or_increment(bend, n_eff, self._get_initial_operands())
|
|
1109
|
+
new = eqx.tree_at(
|
|
1110
|
+
lambda m: (m.key, m.initial, m.curr_initial_idx),
|
|
1111
|
+
self,
|
|
1112
|
+
new_attributes,
|
|
1113
|
+
)
|
|
1114
|
+
return new, jax.lax.dynamic_slice(
|
|
1115
|
+
new.initial,
|
|
1116
|
+
start_indices=(new.curr_initial_idx, 0),
|
|
1117
|
+
slice_sizes=(new.initial_batch_size, new.dim),
|
|
908
1118
|
)
|
|
909
1119
|
|
|
910
1120
|
def get_batch(self) -> tuple["CubicMeshPDENonStatio", PDENonStatioBatch]:
|
|
911
1121
|
"""
|
|
912
|
-
Generic method to return a batch. Here we call `self.
|
|
913
|
-
`self.border_batch()` and `self.
|
|
1122
|
+
Generic method to return a batch. Here we call `self.domain_batch()`,
|
|
1123
|
+
`self.border_batch()` and `self.initial_batch()`
|
|
914
1124
|
"""
|
|
915
|
-
new,
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
t = t.reshape(new.temporal_batch_size, 1)
|
|
919
|
-
|
|
920
|
-
if new.cartesian_product:
|
|
921
|
-
t_x = make_cartesian_product(t, x)
|
|
1125
|
+
new, domain = self.domain_batch()
|
|
1126
|
+
if self.border is not None:
|
|
1127
|
+
new, border = new.border_batch()
|
|
922
1128
|
else:
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
t_ = t.reshape(new.temporal_batch_size, 1, 1)
|
|
927
|
-
t_ = jnp.repeat(t_, dx.shape[-1], axis=2)
|
|
928
|
-
if new.cartesian_product or new.dim == 1:
|
|
929
|
-
t_dx = make_cartesian_product(t_, dx)
|
|
930
|
-
else:
|
|
931
|
-
t_dx = jnp.concatenate([t_, dx], axis=1)
|
|
1129
|
+
border = None
|
|
1130
|
+
if self.initial is not None:
|
|
1131
|
+
new, initial = new.initial_batch()
|
|
932
1132
|
else:
|
|
933
|
-
|
|
1133
|
+
initial = None
|
|
934
1134
|
|
|
935
1135
|
return new, PDENonStatioBatch(
|
|
936
|
-
|
|
1136
|
+
domain_batch=domain, border_batch=border, initial_batch=initial
|
|
937
1137
|
)
|
|
938
1138
|
|
|
939
1139
|
|
|
940
1140
|
class DataGeneratorObservations(eqx.Module):
|
|
941
1141
|
r"""
|
|
942
|
-
Despite the class name, it is rather a dataloader
|
|
943
|
-
observations
|
|
1142
|
+
Despite the class name, it is rather a dataloader for user-provided
|
|
1143
|
+
observations which will are used in the observations loss.
|
|
944
1144
|
|
|
945
1145
|
Parameters
|
|
946
1146
|
----------
|
|
947
1147
|
key : Key
|
|
948
1148
|
Jax random key to shuffle batches
|
|
949
|
-
obs_batch_size : Int
|
|
1149
|
+
obs_batch_size : Int | None
|
|
950
1150
|
The size of the batch of randomly selected points among
|
|
951
|
-
the `n` points.
|
|
952
|
-
elements of the return observation dict batch.
|
|
953
|
-
NOTE: no check is done BUT users should be careful that
|
|
954
|
-
`obs_batch_size` must be equal to `temporal_batch_size` or
|
|
955
|
-
`omega_batch_size` or the product of both. In the first case, the
|
|
956
|
-
present DataGeneratorObservations instance complements an ODEBatch,
|
|
957
|
-
PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
|
|
958
|
-
= False). In the second case, `obs_batch_size` =
|
|
959
|
-
`temporal_batch_size * omega_batch_size` if the present
|
|
960
|
-
DataGeneratorParameter complements a PDENonStatioBatch
|
|
961
|
-
with self.cartesian_product = True
|
|
1151
|
+
the `n` points. If None, no minibatch are used.
|
|
962
1152
|
observed_pinn_in : Float[Array, "n_obs nb_pinn_in"]
|
|
963
1153
|
Observed values corresponding to the input of the PINN
|
|
964
1154
|
(eg. the time at which we recorded the observations). The first
|
|
@@ -984,11 +1174,11 @@ class DataGeneratorObservations(eqx.Module):
|
|
|
984
1174
|
datasets of observations. Note that computations for **batches**
|
|
985
1175
|
can still be performed on other devices (*e.g.* GPU, TPU or
|
|
986
1176
|
any pre-defined Sharding) thanks to the `obs_batch_sharding`
|
|
987
|
-
arguments of `jinns.solve()`. Read
|
|
1177
|
+
arguments of `jinns.solve()`. Read `jinns.solve()` doc for more info.
|
|
988
1178
|
"""
|
|
989
1179
|
|
|
990
1180
|
key: Key
|
|
991
|
-
obs_batch_size: Int = eqx.field(static=True)
|
|
1181
|
+
obs_batch_size: Int | None = eqx.field(static=True)
|
|
992
1182
|
observed_pinn_in: Float[Array, "n_obs nb_pinn_in"]
|
|
993
1183
|
observed_values: Float[Array, "n_obs nb_pinn_out"]
|
|
994
1184
|
observed_eq_params: Dict[str, Float[Array, "n_obs 1"]] = eqx.field(
|
|
@@ -996,7 +1186,7 @@ class DataGeneratorObservations(eqx.Module):
|
|
|
996
1186
|
)
|
|
997
1187
|
sharding_device: jax.sharding.Sharding = eqx.field(static=True, default=None)
|
|
998
1188
|
|
|
999
|
-
n: Int = eqx.field(init=False)
|
|
1189
|
+
n: Int = eqx.field(init=False, static=True)
|
|
1000
1190
|
curr_idx: Int = eqx.field(init=False)
|
|
1001
1191
|
indices: Array = eqx.field(init=False)
|
|
1002
1192
|
|
|
@@ -1040,7 +1230,11 @@ class DataGeneratorObservations(eqx.Module):
|
|
|
1040
1230
|
self.observed_eq_params, self.sharding_device
|
|
1041
1231
|
)
|
|
1042
1232
|
|
|
1043
|
-
|
|
1233
|
+
if self.obs_batch_size is not None:
|
|
1234
|
+
self.curr_idx = self.n + self.obs_batch_size
|
|
1235
|
+
# to be sure there is a shuffling at first get_batch()
|
|
1236
|
+
else:
|
|
1237
|
+
self.curr_idx = 0
|
|
1044
1238
|
# For speed and to avoid duplicating data what is really
|
|
1045
1239
|
# shuffled is a vector of indices
|
|
1046
1240
|
if self.sharding_device is not None:
|
|
@@ -1079,6 +1273,13 @@ class DataGeneratorObservations(eqx.Module):
|
|
|
1079
1273
|
observed_pinn_in, observed_values, etc. are dictionaries with keys
|
|
1080
1274
|
representing the PINNs.
|
|
1081
1275
|
"""
|
|
1276
|
+
if self.obs_batch_size is None or self.obs_batch_size == self.n:
|
|
1277
|
+
# Avoid unnecessary reshuffling
|
|
1278
|
+
return self, {
|
|
1279
|
+
"pinn_in": self.observed_pinn_in,
|
|
1280
|
+
"val": self.observed_values,
|
|
1281
|
+
"eq_params": self.observed_eq_params,
|
|
1282
|
+
}
|
|
1082
1283
|
|
|
1083
1284
|
new_attributes = _reset_or_increment(
|
|
1084
1285
|
self.curr_idx + self.obs_batch_size, self.n, self._get_operands()
|
|
@@ -1120,7 +1321,9 @@ class DataGeneratorObservations(eqx.Module):
|
|
|
1120
1321
|
|
|
1121
1322
|
class DataGeneratorParameter(eqx.Module):
|
|
1122
1323
|
r"""
|
|
1123
|
-
A data generator for additional unidimensional parameter(s)
|
|
1324
|
+
A data generator for additional unidimensional equation parameter(s).
|
|
1325
|
+
Mostly useful for metamodeling where batch of `params.eq_params` are fed
|
|
1326
|
+
to the network.
|
|
1124
1327
|
|
|
1125
1328
|
Parameters
|
|
1126
1329
|
----------
|
|
@@ -1131,19 +1334,11 @@ class DataGeneratorParameter(eqx.Module):
|
|
|
1131
1334
|
The number of total points that will be divided in
|
|
1132
1335
|
batches. Batches are made so that each data point is seen only
|
|
1133
1336
|
once during 1 epoch.
|
|
1134
|
-
param_batch_size : Int
|
|
1337
|
+
param_batch_size : Int | None, default=None
|
|
1135
1338
|
The size of the batch of randomly selected points among
|
|
1136
|
-
the `n` points.
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
`param_batch_size` must be equal to `temporal_batch_size` or
|
|
1140
|
-
`omega_batch_size` or the product of both. In the first case, the
|
|
1141
|
-
present DataGeneratorParameter instance complements an ODEBatch, a
|
|
1142
|
-
PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
|
|
1143
|
-
= False). In the second case, `param_batch_size` =
|
|
1144
|
-
`temporal_batch_size * omega_batch_size` if the present
|
|
1145
|
-
DataGeneratorParameter complements a PDENonStatioBatch
|
|
1146
|
-
with self.cartesian_product = True
|
|
1339
|
+
the `n` points. **Important**: no check is performed but
|
|
1340
|
+
`param_batch_size` must be the same as other collocation points
|
|
1341
|
+
batch_size (time, space or timexspace depending on the context). This is because we vmap the network on all its axes at once to compute the MSE. Also, `param_batch_size` will be the same for all parameters. If None, no mini-batches are used.
|
|
1147
1342
|
param_ranges : Dict[str, tuple[Float, Float] | None, default={}
|
|
1148
1343
|
A dict. A dict of tuples (min, max), which
|
|
1149
1344
|
reprensents the range of real numbers where to sample batches (of
|
|
@@ -1153,31 +1348,31 @@ class DataGeneratorParameter(eqx.Module):
|
|
|
1153
1348
|
By providing several entries in this dictionary we can sample
|
|
1154
1349
|
an arbitrary number of parameters.
|
|
1155
1350
|
**Note** that we currently only support unidimensional parameters.
|
|
1156
|
-
This argument can be
|
|
1351
|
+
This argument can be None if we use `user_data`.
|
|
1157
1352
|
method : str, default="uniform"
|
|
1158
1353
|
Either `grid` or `uniform`, default is `uniform`. `grid` means
|
|
1159
1354
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
1160
1355
|
sampled points over the domain
|
|
1161
|
-
user_data : Dict[str, Float[
|
|
1356
|
+
user_data : Dict[str, Float[jnp.ndarray, "n"]] | None, default={}
|
|
1162
1357
|
A dictionary containing user-provided data for parameters.
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
unidimensional
|
|
1166
|
-
|
|
1358
|
+
The keys corresponds to the parameter name,
|
|
1359
|
+
and must match the keys in `params["eq_params"]`. Only
|
|
1360
|
+
unidimensional `jnp.array` are supported. Therefore, the array at
|
|
1361
|
+
`user_data[k]` must have shape `(n, 1)` or `(n,)`.
|
|
1167
1362
|
Note that if the same key appears in `param_ranges` and `user_data`
|
|
1168
1363
|
priority goes for the content in `user_data`.
|
|
1169
1364
|
Defaults to None.
|
|
1170
1365
|
"""
|
|
1171
1366
|
|
|
1172
1367
|
keys: Key | Dict[str, Key]
|
|
1173
|
-
n: Int
|
|
1174
|
-
param_batch_size: Int = eqx.field(static=True)
|
|
1368
|
+
n: Int = eqx.field(static=True)
|
|
1369
|
+
param_batch_size: Int | None = eqx.field(static=True, default=None)
|
|
1175
1370
|
param_ranges: Dict[str, tuple[Float, Float]] = eqx.field(
|
|
1176
1371
|
static=True, default_factory=lambda: {}
|
|
1177
1372
|
)
|
|
1178
1373
|
method: str = eqx.field(static=True, default="uniform")
|
|
1179
|
-
user_data: Dict[str, Float[Array, "n"]] | None = eqx.field(
|
|
1180
|
-
|
|
1374
|
+
user_data: Dict[str, Float[onp.Array, "n"]] | None = eqx.field(
|
|
1375
|
+
default_factory=lambda: {}
|
|
1181
1376
|
)
|
|
1182
1377
|
|
|
1183
1378
|
curr_param_idx: Dict[str, Int] = eqx.field(init=False)
|
|
@@ -1197,11 +1392,13 @@ class DataGeneratorParameter(eqx.Module):
|
|
|
1197
1392
|
all_keys = set().union(self.param_ranges, self.user_data)
|
|
1198
1393
|
self.keys = dict(zip(all_keys, jax.random.split(self.keys, len(all_keys))))
|
|
1199
1394
|
|
|
1200
|
-
self.
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
)
|
|
1395
|
+
if self.param_batch_size is None:
|
|
1396
|
+
self.curr_param_idx = None
|
|
1397
|
+
else:
|
|
1398
|
+
self.curr_param_idx = {}
|
|
1399
|
+
for k in self.keys.keys():
|
|
1400
|
+
self.curr_param_idx[k] = self.n + self.param_batch_size
|
|
1401
|
+
# to be sure there is a shuffling at first get_batch()
|
|
1205
1402
|
|
|
1206
1403
|
# The call to self.generate_data() creates
|
|
1207
1404
|
# the dict self.param_n_samples and then we will only use this one
|
|
@@ -1268,6 +1465,9 @@ class DataGeneratorParameter(eqx.Module):
|
|
|
1268
1465
|
otherwise we just return the next unseen batch.
|
|
1269
1466
|
"""
|
|
1270
1467
|
|
|
1468
|
+
if self.param_batch_size is None or self.param_batch_size == self.n:
|
|
1469
|
+
return self, self.param_n_samples
|
|
1470
|
+
|
|
1271
1471
|
def _reset_or_increment_wrapper(param_k, idx_k, key_k):
|
|
1272
1472
|
return _reset_or_increment(
|
|
1273
1473
|
idx_k + self.param_batch_size,
|
|
@@ -1319,7 +1519,7 @@ class DataGeneratorObservationsMultiPINNs(eqx.Module):
|
|
|
1319
1519
|
|
|
1320
1520
|
Technically, the constraint on the observations in SystemLossXDE are
|
|
1321
1521
|
applied in `constraints_system_loss_apply` and in this case the
|
|
1322
|
-
batch.obs_batch_dict is a dict of obs_batch_dict over which the tree_map
|
|
1522
|
+
`batch.obs_batch_dict` is a dict of obs_batch_dict over which the tree_map
|
|
1323
1523
|
applies (we select the obs_batch_dict corresponding to its `u_dict` entry)
|
|
1324
1524
|
|
|
1325
1525
|
Parameters
|
|
@@ -1328,15 +1528,6 @@ class DataGeneratorObservationsMultiPINNs(eqx.Module):
|
|
|
1328
1528
|
The size of the batch of randomly selected observations
|
|
1329
1529
|
`obs_batch_size` will be the same for all the
|
|
1330
1530
|
elements of the obs dict.
|
|
1331
|
-
NOTE: no check is done BUT users should be careful that
|
|
1332
|
-
`obs_batch_size` must be equal to `temporal_batch_size` or
|
|
1333
|
-
`omega_batch_size` or the product of both. In the first case, the
|
|
1334
|
-
present DataGeneratorObservations instance complements an ODEBatch,
|
|
1335
|
-
PDEStatioBatch or a PDENonStatioBatch (with self.cartesian_product
|
|
1336
|
-
= False). In the second case, `obs_batch_size` =
|
|
1337
|
-
`temporal_batch_size * omega_batch_size` if the present
|
|
1338
|
-
DataGeneratorParameter complements a PDENonStatioBatch
|
|
1339
|
-
with self.cartesian_product = True
|
|
1340
1531
|
observed_pinn_in_dict : Dict[str, Float[Array, "n_obs nb_pinn_in"] | None]
|
|
1341
1532
|
A dict of observed_pinn_in as defined in DataGeneratorObservations.
|
|
1342
1533
|
Keys must be that of `u_dict`.
|