jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
|
@@ -8,8 +8,11 @@ from __future__ import (
|
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
import jax
|
|
11
|
+
import numpy as np
|
|
11
12
|
import jax.numpy as jnp
|
|
12
|
-
from
|
|
13
|
+
from scipy.stats import qmc
|
|
14
|
+
from jaxtyping import PRNGKeyArray, Array, Float
|
|
15
|
+
from typing import Literal
|
|
13
16
|
from jinns.data._Batchs import PDEStatioBatch
|
|
14
17
|
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
18
|
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
@@ -22,7 +25,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
22
25
|
|
|
23
26
|
Parameters
|
|
24
27
|
----------
|
|
25
|
-
key :
|
|
28
|
+
key : PRNGKeyArray
|
|
26
29
|
Jax random key to sample new time points and to shuffle batches
|
|
27
30
|
n : int
|
|
28
31
|
The number of total $\Omega$ points that will be divided in
|
|
@@ -50,11 +53,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
50
53
|
A tuple of maximum values of the domain along each dimension. For a sampling
|
|
51
54
|
in `n` dimension, this represents $(x_{1, max}, x_{2,max}, ...,
|
|
52
55
|
x_{n,max})$
|
|
53
|
-
method :
|
|
54
|
-
Either
|
|
56
|
+
method : Literal["grid", "uniform", "sobol", "halton"], default="uniform"
|
|
57
|
+
Either "grid", "uniform", "sobol" or "halton", default is `uniform`.
|
|
55
58
|
The method that generates the `nt` time points. `grid` means
|
|
56
59
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
57
|
-
sampled points over the domain
|
|
60
|
+
sampled points over the domain.
|
|
61
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
62
|
+
be JIT compatible.
|
|
58
63
|
rar_parameters : dict[str, int], default=None
|
|
59
64
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
60
65
|
Otherwise a dictionary with keys
|
|
@@ -75,32 +80,28 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
75
80
|
then corresponds to the initial number of points we train the PINN on.
|
|
76
81
|
"""
|
|
77
82
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
n: int = eqx.field(kw_only=True, static=True)
|
|
83
|
+
key: PRNGKeyArray
|
|
84
|
+
n: int = eqx.field(static=True)
|
|
81
85
|
nb: int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
82
86
|
omega_batch_size: int | None = eqx.field(
|
|
83
|
-
kw_only=True,
|
|
84
87
|
static=True,
|
|
85
|
-
|
|
88
|
+
# can be None as
|
|
86
89
|
# CubicMeshPDENonStatio inherits but also if omega_batch_size=n
|
|
87
90
|
) # static cause used as a
|
|
88
91
|
# shape in jax.lax.dynamic_slice
|
|
89
92
|
omega_border_batch_size: int | None = eqx.field(
|
|
90
|
-
|
|
93
|
+
static=True,
|
|
91
94
|
) # static cause used as a
|
|
92
95
|
# shape in jax.lax.dynamic_slice
|
|
93
|
-
dim: int = eqx.field(
|
|
96
|
+
dim: int = eqx.field(static=True) # static cause used as a
|
|
94
97
|
# shape in jax.lax.dynamic_slice
|
|
95
|
-
min_pts: tuple[float, ...]
|
|
96
|
-
max_pts: tuple[float, ...]
|
|
97
|
-
method:
|
|
98
|
-
|
|
99
|
-
)
|
|
100
|
-
rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
|
|
101
|
-
n_start: int = eqx.field(kw_only=True, default=None, static=True)
|
|
98
|
+
min_pts: tuple[float, ...]
|
|
99
|
+
max_pts: tuple[float, ...]
|
|
100
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(static=True)
|
|
101
|
+
rar_parameters: None | dict[str, int]
|
|
102
|
+
n_start: int = eqx.field(static=True)
|
|
102
103
|
|
|
103
|
-
#
|
|
104
|
+
# --- Below fields are not passed as arguments to __init__
|
|
104
105
|
p: Float[Array, " n"] | None = eqx.field(init=False)
|
|
105
106
|
rar_iter_from_last_sampling: int | None = eqx.field(init=False)
|
|
106
107
|
rar_iter_nb: int | None = eqx.field(init=False)
|
|
@@ -111,7 +112,32 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
111
112
|
eqx.field(init=False)
|
|
112
113
|
)
|
|
113
114
|
|
|
114
|
-
def
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
*,
|
|
118
|
+
key: PRNGKeyArray,
|
|
119
|
+
n: int,
|
|
120
|
+
nb: int | None = None,
|
|
121
|
+
omega_batch_size: int | None = None,
|
|
122
|
+
omega_border_batch_size: int | None = None,
|
|
123
|
+
dim: int,
|
|
124
|
+
min_pts: tuple[float, ...],
|
|
125
|
+
max_pts: tuple[float, ...],
|
|
126
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = "uniform",
|
|
127
|
+
rar_parameters: dict[str, int] | None = None,
|
|
128
|
+
n_start: int | None = None,
|
|
129
|
+
):
|
|
130
|
+
self.key = key
|
|
131
|
+
self.n = n
|
|
132
|
+
self.nb = nb
|
|
133
|
+
self.omega_batch_size = omega_batch_size
|
|
134
|
+
self.omega_border_batch_size = omega_border_batch_size
|
|
135
|
+
self.dim = dim
|
|
136
|
+
self.min_pts = min_pts
|
|
137
|
+
self.max_pts = max_pts
|
|
138
|
+
self.method = method
|
|
139
|
+
self.rar_parameters = rar_parameters
|
|
140
|
+
|
|
115
141
|
assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
|
|
116
142
|
assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
|
|
117
143
|
|
|
@@ -120,7 +146,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
120
146
|
self.p,
|
|
121
147
|
self.rar_iter_from_last_sampling,
|
|
122
148
|
self.rar_iter_nb,
|
|
123
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.n,
|
|
149
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, n_start)
|
|
124
150
|
|
|
125
151
|
if self.method == "grid" and self.dim == 2:
|
|
126
152
|
perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
@@ -132,6 +158,22 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
132
158
|
)
|
|
133
159
|
self.n = perfect_sq
|
|
134
160
|
|
|
161
|
+
if self.method in ["sobol", "halton"]:
|
|
162
|
+
log2_n = jnp.log2(self.n)
|
|
163
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
164
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
165
|
+
closest_two_power = (
|
|
166
|
+
lower_pow
|
|
167
|
+
if (self.n - lower_pow) < (higher_pow - self.n)
|
|
168
|
+
else higher_pow
|
|
169
|
+
)
|
|
170
|
+
if self.n != closest_two_power:
|
|
171
|
+
warnings.warn(
|
|
172
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
173
|
+
f"Modfiying self.n from {self.n} to {closest_two_power}.",
|
|
174
|
+
)
|
|
175
|
+
self.n = int(closest_two_power)
|
|
176
|
+
|
|
135
177
|
if self.omega_batch_size is None:
|
|
136
178
|
self.curr_omega_idx = 0
|
|
137
179
|
else:
|
|
@@ -174,29 +216,53 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
174
216
|
self.key, self.omega_border = self.generate_omega_border_data(self.key)
|
|
175
217
|
|
|
176
218
|
def sample_in_omega_domain(
|
|
177
|
-
self, keys:
|
|
219
|
+
self, keys: list[PRNGKeyArray], sample_size: int
|
|
178
220
|
) -> Float[Array, " n dim"]:
|
|
221
|
+
if self.method == "uniform":
|
|
222
|
+
if self.dim == 1:
|
|
223
|
+
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
224
|
+
return jax.random.uniform(
|
|
225
|
+
*keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
return jnp.concatenate(
|
|
229
|
+
[
|
|
230
|
+
jax.random.uniform(
|
|
231
|
+
keys[i],
|
|
232
|
+
(sample_size, 1),
|
|
233
|
+
minval=self.min_pts[i],
|
|
234
|
+
maxval=self.max_pts[i],
|
|
235
|
+
)
|
|
236
|
+
for i in range(self.dim)
|
|
237
|
+
],
|
|
238
|
+
axis=-1,
|
|
239
|
+
)
|
|
240
|
+
else:
|
|
241
|
+
return self._qmc_in_omega_domain(keys[0], sample_size)
|
|
242
|
+
|
|
243
|
+
def _qmc_in_omega_domain(
|
|
244
|
+
self, subkey: PRNGKeyArray, sample_size: int
|
|
245
|
+
) -> Float[Array, "n dim"]:
|
|
246
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
179
247
|
if self.dim == 1:
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
248
|
+
qmc_seq = qmc_generator(
|
|
249
|
+
d=self.dim,
|
|
250
|
+
scramble=True,
|
|
251
|
+
rng=np.random.default_rng(np.uint32(subkey)),
|
|
183
252
|
)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
minval=self.min_pts[i],
|
|
191
|
-
maxval=self.max_pts[i],
|
|
192
|
-
)
|
|
193
|
-
for i in range(self.dim)
|
|
194
|
-
],
|
|
195
|
-
axis=-1,
|
|
253
|
+
u = qmc_seq.random(n=sample_size)
|
|
254
|
+
return jnp.array(
|
|
255
|
+
qmc.scale(u, l_bounds=self.min_pts[0], u_bounds=self.max_pts[0])
|
|
256
|
+
)
|
|
257
|
+
sampler = qmc.Sobol(
|
|
258
|
+
d=self.dim, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
196
259
|
)
|
|
260
|
+
samples = sampler.random(n=sample_size)
|
|
261
|
+
samples = qmc.scale(samples, l_bounds=self.min_pts, u_bounds=self.max_pts)
|
|
262
|
+
return jnp.array(samples)
|
|
197
263
|
|
|
198
264
|
def sample_in_omega_border_domain(
|
|
199
|
-
self, keys:
|
|
265
|
+
self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
|
|
200
266
|
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
201
267
|
sample_size = self.nb if sample_size is None else sample_size
|
|
202
268
|
if sample_size is None:
|
|
@@ -206,6 +272,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
206
272
|
xmax = self.max_pts[0]
|
|
207
273
|
return jnp.array([xmin, xmax]).astype(float)
|
|
208
274
|
if self.dim == 2:
|
|
275
|
+
assert keys is not None
|
|
209
276
|
# currently hard-coded the 4 edges for d==2
|
|
210
277
|
# TODO : find a general & efficient way to sample from the border
|
|
211
278
|
# (facets) of the hypercube in general dim.
|
|
@@ -260,10 +327,67 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
260
327
|
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
261
328
|
)
|
|
262
329
|
|
|
330
|
+
def qmc_in_omega_border_domain(
|
|
331
|
+
self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
|
|
332
|
+
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
333
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
334
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
335
|
+
if sample_size is None:
|
|
336
|
+
return None
|
|
337
|
+
if self.dim == 1:
|
|
338
|
+
xmin = self.min_pts[0]
|
|
339
|
+
xmax = self.max_pts[0]
|
|
340
|
+
return jnp.array([xmin, xmax]).astype(float)
|
|
341
|
+
if self.dim == 2:
|
|
342
|
+
assert keys is not None
|
|
343
|
+
# currently hard-coded the 4 edges for d==2
|
|
344
|
+
# TODO : find a general & efficient way to sample from the border
|
|
345
|
+
# (facets) of the hypercube in general dim.
|
|
346
|
+
facet_n = sample_size // (2 * self.dim)
|
|
347
|
+
|
|
348
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
349
|
+
qmc_seq = qmc_generator(
|
|
350
|
+
d=1,
|
|
351
|
+
scramble=True,
|
|
352
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
353
|
+
)
|
|
354
|
+
u = qmc_seq.random(n=facet_n)
|
|
355
|
+
return jnp.array(qmc.scale(u, l_bounds=min_val, u_bounds=max_val))
|
|
356
|
+
|
|
357
|
+
xmin = jnp.hstack(
|
|
358
|
+
[
|
|
359
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
360
|
+
generate_qmc_sample(keys[0], self.min_pts[1], self.max_pts[1]),
|
|
361
|
+
]
|
|
362
|
+
)
|
|
363
|
+
xmax = jnp.hstack(
|
|
364
|
+
[
|
|
365
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
366
|
+
generate_qmc_sample(keys[1], self.min_pts[1], self.max_pts[1]),
|
|
367
|
+
]
|
|
368
|
+
)
|
|
369
|
+
ymin = jnp.hstack(
|
|
370
|
+
[
|
|
371
|
+
generate_qmc_sample(keys[2], self.min_pts[0], self.max_pts[0]),
|
|
372
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
373
|
+
]
|
|
374
|
+
)
|
|
375
|
+
ymax = jnp.hstack(
|
|
376
|
+
[
|
|
377
|
+
generate_qmc_sample(keys[3], self.min_pts[0], self.max_pts[0]),
|
|
378
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
379
|
+
]
|
|
380
|
+
)
|
|
381
|
+
return jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
382
|
+
raise NotImplementedError(
|
|
383
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
384
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
385
|
+
)
|
|
386
|
+
|
|
263
387
|
def generate_omega_data(
|
|
264
|
-
self, key:
|
|
388
|
+
self, key: PRNGKeyArray, data_size: int | None = None
|
|
265
389
|
) -> tuple[
|
|
266
|
-
|
|
390
|
+
PRNGKeyArray,
|
|
267
391
|
Float[Array, " n dim"],
|
|
268
392
|
]:
|
|
269
393
|
r"""
|
|
@@ -290,20 +414,21 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
290
414
|
)
|
|
291
415
|
xyz_ = [a.reshape((data_size, 1)) for a in xyz_]
|
|
292
416
|
omega = jnp.concatenate(xyz_, axis=-1)
|
|
293
|
-
elif self.method
|
|
294
|
-
if self.dim == 1:
|
|
295
|
-
key,
|
|
417
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
418
|
+
if self.dim == 1 or self.method in ["sobol", "halton"]:
|
|
419
|
+
key, subkey = jax.random.split(key, 2)
|
|
420
|
+
omega = self.sample_in_omega_domain([subkey], sample_size=data_size)
|
|
296
421
|
else:
|
|
297
422
|
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
298
|
-
|
|
423
|
+
omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
|
|
299
424
|
else:
|
|
300
425
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
301
426
|
return key, omega
|
|
302
427
|
|
|
303
428
|
def generate_omega_border_data(
|
|
304
|
-
self, key:
|
|
429
|
+
self, key: PRNGKeyArray, data_size: int | None = None
|
|
305
430
|
) -> tuple[
|
|
306
|
-
|
|
431
|
+
PRNGKeyArray,
|
|
307
432
|
Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
|
|
308
433
|
]:
|
|
309
434
|
r"""
|
|
@@ -317,15 +442,24 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
317
442
|
key, *subkeys = jax.random.split(key, 5)
|
|
318
443
|
else:
|
|
319
444
|
subkeys = None
|
|
320
|
-
omega_border = self.sample_in_omega_border_domain(
|
|
321
|
-
subkeys, sample_size=data_size
|
|
322
|
-
)
|
|
323
445
|
|
|
446
|
+
if self.method in ["grid", "uniform"]:
|
|
447
|
+
omega_border = self.sample_in_omega_border_domain(
|
|
448
|
+
subkeys, sample_size=data_size
|
|
449
|
+
)
|
|
450
|
+
elif self.method in ["sobol", "halton"]:
|
|
451
|
+
omega_border = self.qmc_in_omega_border_domain(
|
|
452
|
+
subkeys, sample_size=data_size
|
|
453
|
+
)
|
|
454
|
+
else:
|
|
455
|
+
raise ValueError("Method " + self.method + " is not implemented.")
|
|
324
456
|
return key, omega_border
|
|
325
457
|
|
|
326
458
|
def _get_omega_operands(
|
|
327
459
|
self,
|
|
328
|
-
) -> tuple[
|
|
460
|
+
) -> tuple[
|
|
461
|
+
PRNGKeyArray, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None
|
|
462
|
+
]:
|
|
329
463
|
return (
|
|
330
464
|
self.key,
|
|
331
465
|
self.omega,
|
|
@@ -367,7 +501,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
367
501
|
# handled above
|
|
368
502
|
)
|
|
369
503
|
new = eqx.tree_at(
|
|
370
|
-
lambda m: (m.key, m.omega, m.curr_omega_idx),
|
|
504
|
+
lambda m: (m.key, m.omega, m.curr_omega_idx), # type: ignore
|
|
505
|
+
self,
|
|
506
|
+
new_attributes,
|
|
371
507
|
)
|
|
372
508
|
|
|
373
509
|
return new, jax.lax.dynamic_slice(
|
|
@@ -379,7 +515,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
379
515
|
def _get_omega_border_operands(
|
|
380
516
|
self,
|
|
381
517
|
) -> tuple[
|
|
382
|
-
|
|
518
|
+
PRNGKeyArray,
|
|
383
519
|
Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
|
|
384
520
|
int,
|
|
385
521
|
int | None,
|
|
@@ -443,7 +579,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
443
579
|
# handled above
|
|
444
580
|
)
|
|
445
581
|
new = eqx.tree_at(
|
|
446
|
-
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
|
|
582
|
+
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx), # type: ignore
|
|
447
583
|
self,
|
|
448
584
|
new_attributes,
|
|
449
585
|
)
|
jinns/data/_DataGeneratorODE.py
CHANGED
|
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
|
|
|
9
9
|
import equinox as eqx
|
|
10
10
|
import jax
|
|
11
11
|
import jax.numpy as jnp
|
|
12
|
-
from jaxtyping import
|
|
12
|
+
from jaxtyping import PRNGKeyArray, Array, Float
|
|
13
13
|
from jinns.data._Batchs import ODEBatch
|
|
14
14
|
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
15
15
|
from jinns.data._AbstractDataGenerator import AbstractDataGenerator
|
|
@@ -24,7 +24,7 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
24
24
|
|
|
25
25
|
Parameters
|
|
26
26
|
----------
|
|
27
|
-
key :
|
|
27
|
+
key : PRNGKeyArray
|
|
28
28
|
Jax random key to sample new time points and to shuffle batches
|
|
29
29
|
nt : int
|
|
30
30
|
The number of total time points that will be divided in
|
|
@@ -42,10 +42,10 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
42
42
|
The method that generates the `nt` time points. `grid` means
|
|
43
43
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
44
44
|
sampled points over the domain
|
|
45
|
-
rar_parameters : RarParameterDict, default=None
|
|
45
|
+
rar_parameters : None | RarParameterDict, default=None
|
|
46
46
|
A TypedDict to specify the Residual Adaptative Resampling procedure. See
|
|
47
47
|
the docstring from RarParameterDict
|
|
48
|
-
n_start : int, default=None
|
|
48
|
+
n_start : None | int, default=None
|
|
49
49
|
Defaults to None. The effective size of nt used at start time.
|
|
50
50
|
This value must be
|
|
51
51
|
provided when rar_parameters is not None. Otherwise we set internally
|
|
@@ -54,25 +54,43 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
54
54
|
then corresponds to the initial number of points we train the PINN.
|
|
55
55
|
"""
|
|
56
56
|
|
|
57
|
-
key:
|
|
58
|
-
nt: int = eqx.field(
|
|
59
|
-
tmin:
|
|
60
|
-
tmax:
|
|
61
|
-
temporal_batch_size: int | None = eqx.field(static=True
|
|
62
|
-
method: str = eqx.field(
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
# all the init=False fields are set in __post_init__
|
|
57
|
+
key: PRNGKeyArray
|
|
58
|
+
nt: int = eqx.field(static=True)
|
|
59
|
+
tmin: float
|
|
60
|
+
tmax: float
|
|
61
|
+
temporal_batch_size: int | None = eqx.field(static=True)
|
|
62
|
+
method: str = eqx.field(static=True)
|
|
63
|
+
rar_parameters: None | dict[str, int]
|
|
64
|
+
n_start: None | int
|
|
65
|
+
|
|
66
|
+
# --- Below fields are not passed as arguments to __init__
|
|
69
67
|
p: Float[Array, " nt 1"] | None = eqx.field(init=False)
|
|
70
68
|
rar_iter_from_last_sampling: int | None = eqx.field(init=False)
|
|
71
69
|
rar_iter_nb: int | None = eqx.field(init=False)
|
|
72
70
|
curr_time_idx: int = eqx.field(init=False)
|
|
73
71
|
times: Float[Array, " nt 1"] = eqx.field(init=False)
|
|
74
72
|
|
|
75
|
-
def
|
|
73
|
+
def __init__(
|
|
74
|
+
self,
|
|
75
|
+
*,
|
|
76
|
+
key: PRNGKeyArray,
|
|
77
|
+
nt: int,
|
|
78
|
+
tmin: float,
|
|
79
|
+
tmax: float,
|
|
80
|
+
temporal_batch_size: int | None,
|
|
81
|
+
method: str = "uniform",
|
|
82
|
+
rar_parameters: None | dict[str, int] = None,
|
|
83
|
+
n_start: None | int = None,
|
|
84
|
+
):
|
|
85
|
+
self.key = key
|
|
86
|
+
self.nt = nt
|
|
87
|
+
self.tmin = tmin
|
|
88
|
+
self.tmax = tmax
|
|
89
|
+
self.temporal_batch_size = temporal_batch_size
|
|
90
|
+
self.method = method
|
|
91
|
+
self.n_start = n_start
|
|
92
|
+
self.rar_parameters = rar_parameters
|
|
93
|
+
|
|
76
94
|
(
|
|
77
95
|
self.n_start,
|
|
78
96
|
self.p,
|
|
@@ -97,7 +115,7 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
97
115
|
# above way for the key.
|
|
98
116
|
|
|
99
117
|
def sample_in_time_domain(
|
|
100
|
-
self, key:
|
|
118
|
+
self, key: PRNGKeyArray, sample_size: int | None = None
|
|
101
119
|
) -> Float[Array, " nt 1"]:
|
|
102
120
|
return jax.random.uniform(
|
|
103
121
|
key,
|
|
@@ -106,7 +124,9 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
106
124
|
maxval=self.tmax,
|
|
107
125
|
)
|
|
108
126
|
|
|
109
|
-
def generate_time_data(
|
|
127
|
+
def generate_time_data(
|
|
128
|
+
self, key: PRNGKeyArray
|
|
129
|
+
) -> tuple[PRNGKeyArray, Float[Array, " nt"]]:
|
|
110
130
|
"""
|
|
111
131
|
Construct a complete set of `self.nt` time points according to the
|
|
112
132
|
specified `self.method`
|
|
@@ -125,7 +145,11 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
125
145
|
def _get_time_operands(
|
|
126
146
|
self,
|
|
127
147
|
) -> tuple[
|
|
128
|
-
|
|
148
|
+
PRNGKeyArray,
|
|
149
|
+
Float[Array, " nt 1"],
|
|
150
|
+
int,
|
|
151
|
+
int | None,
|
|
152
|
+
Float[Array, " nt 1"] | None,
|
|
129
153
|
]:
|
|
130
154
|
return (
|
|
131
155
|
self.key,
|
|
@@ -150,7 +174,7 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
150
174
|
bend = bstart + self.temporal_batch_size
|
|
151
175
|
|
|
152
176
|
# Compute the effective number of used collocation points
|
|
153
|
-
if self.rar_parameters is not None:
|
|
177
|
+
if self.rar_parameters is not None and self.n_start is not None:
|
|
154
178
|
nt_eff = (
|
|
155
179
|
self.n_start
|
|
156
180
|
+ self.rar_iter_nb # type: ignore
|
|
@@ -167,7 +191,9 @@ class DataGeneratorODE(AbstractDataGenerator):
|
|
|
167
191
|
# handled above
|
|
168
192
|
)
|
|
169
193
|
new = eqx.tree_at(
|
|
170
|
-
lambda m: (m.key, m.times, m.curr_time_idx),
|
|
194
|
+
lambda m: (m.key, m.times, m.curr_time_idx), # type: ignore
|
|
195
|
+
self,
|
|
196
|
+
new_attributes,
|
|
171
197
|
)
|
|
172
198
|
|
|
173
199
|
# commands below are equivalent to
|