jinns 1.5.1__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +55 -34
- jinns/data/_CubicMeshPDEStatio.py +63 -35
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +86 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +139 -184
- jinns/loss/_LossPDE.py +440 -358
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +50 -25
- jinns/solver/_solve.py +3 -3
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/METADATA +2 -2
- jinns-1.6.1.dist-info/RECORD +57 -0
- jinns-1.5.1.dist-info/RECORD +0 -55
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/WHEEL +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.1.dist-info → jinns-1.6.1.dist-info}/top_level.txt +0 -0
jinns/data/_Batchs.py
CHANGED
|
@@ -18,25 +18,59 @@ class ObsBatchDict(TypedDict):
|
|
|
18
18
|
|
|
19
19
|
pinn_in: Float[Array, " obs_batch_size input_dim"]
|
|
20
20
|
val: Float[Array, " obs_batch_size output_dim"]
|
|
21
|
-
eq_params:
|
|
21
|
+
eq_params: (
|
|
22
|
+
eqx.Module | None
|
|
23
|
+
) # None cause sometime user don't provide observed params
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class ODEBatch(eqx.Module):
|
|
25
27
|
temporal_batch: Float[Array, " batch_size"]
|
|
26
|
-
param_batch_dict:
|
|
27
|
-
obs_batch_dict: ObsBatchDict = eqx.field(default=None)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class PDENonStatioBatch(eqx.Module):
|
|
31
|
-
domain_batch: Float[Array, " batch_size 1+dimension"]
|
|
32
|
-
border_batch: Float[Array, " batch_size dimension n_facets"] | None
|
|
33
|
-
initial_batch: Float[Array, " batch_size dimension"] | None
|
|
34
|
-
param_batch_dict: dict[str, Array] = eqx.field(default=None)
|
|
35
|
-
obs_batch_dict: ObsBatchDict = eqx.field(default=None)
|
|
28
|
+
param_batch_dict: eqx.Module | None = eqx.field(default=None)
|
|
29
|
+
obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
36
30
|
|
|
37
31
|
|
|
38
32
|
class PDEStatioBatch(eqx.Module):
|
|
39
33
|
domain_batch: Float[Array, " batch_size dimension"]
|
|
40
34
|
border_batch: Float[Array, " batch_size dimension n_facets"] | None
|
|
41
|
-
param_batch_dict:
|
|
42
|
-
obs_batch_dict: ObsBatchDict
|
|
35
|
+
param_batch_dict: eqx.Module | None
|
|
36
|
+
obs_batch_dict: ObsBatchDict | None
|
|
37
|
+
|
|
38
|
+
# rewrite __init__ to be able to use inheritance for the NonStatio case
|
|
39
|
+
# below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
|
|
40
|
+
# 1) makes more sense and 2) CubicMeshPDENonStatio.get_batch passes pyright.
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
*,
|
|
44
|
+
domain_batch: Float[Array, " batch_size dimension"],
|
|
45
|
+
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
46
|
+
param_batch_dict: eqx.Module | None = None,
|
|
47
|
+
obs_batch_dict: ObsBatchDict | None = None,
|
|
48
|
+
):
|
|
49
|
+
# TODO: document this ?
|
|
50
|
+
self.domain_batch = domain_batch
|
|
51
|
+
self.border_batch = border_batch
|
|
52
|
+
self.param_batch_dict = param_batch_dict
|
|
53
|
+
self.obs_batch_dict = obs_batch_dict
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PDENonStatioBatch(PDEStatioBatch):
|
|
57
|
+
# TODO: document this ?
|
|
58
|
+
domain_batch: Float[Array, " batch_size 1+dimension"] # Override type
|
|
59
|
+
initial_batch: (
|
|
60
|
+
Float[Array, " batch_size dimension"] | None
|
|
61
|
+
) # why can it be None ? Examples?
|
|
62
|
+
|
|
63
|
+
def __init__(
|
|
64
|
+
self,
|
|
65
|
+
*,
|
|
66
|
+
domain_batch: Float[Array, " batch_size 1+dimension"],
|
|
67
|
+
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
68
|
+
initial_batch: Float[Array, " batch_size dimension"] | None,
|
|
69
|
+
param_batch_dict: eqx.Module | None = None,
|
|
70
|
+
obs_batch_dict: ObsBatchDict | None = None,
|
|
71
|
+
):
|
|
72
|
+
self.domain_batch = domain_batch
|
|
73
|
+
self.border_batch = border_batch
|
|
74
|
+
self.initial_batch = initial_batch
|
|
75
|
+
self.param_batch_dict = param_batch_dict
|
|
76
|
+
self.obs_batch_dict = obs_batch_dict
|
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
import jax
|
|
12
12
|
import jax.numpy as jnp
|
|
13
13
|
from scipy.stats import qmc
|
|
14
|
-
from jaxtyping import
|
|
14
|
+
from jaxtyping import PRNGKeyArray, Array, Float
|
|
15
15
|
from jinns.data._Batchs import PDENonStatioBatch
|
|
16
16
|
from jinns.data._utils import (
|
|
17
17
|
make_cartesian_product,
|
|
@@ -29,7 +29,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
29
29
|
|
|
30
30
|
Parameters
|
|
31
31
|
----------
|
|
32
|
-
key :
|
|
32
|
+
key : PRNGKeyArray
|
|
33
33
|
Jax random key to sample new time points and to shuffle batches
|
|
34
34
|
n : int
|
|
35
35
|
The number of total $I\times \Omega$ points that will be divided in
|
|
@@ -50,9 +50,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
50
50
|
among the `nb` points. If None, `domain_batch_size` no
|
|
51
51
|
mini-batches are used.
|
|
52
52
|
initial_batch_size : int | None, default=None
|
|
53
|
-
The
|
|
54
|
-
|
|
55
|
-
mini-batches are used.
|
|
53
|
+
The number of randomly selected points among the `ni` initial spatial
|
|
54
|
+
points used for initial condition. If None, no mini-batches are used.
|
|
56
55
|
dim : int
|
|
57
56
|
An integer. Dimension of $\Omega$ domain.
|
|
58
57
|
min_pts : tuple[tuple[Float, Float], ...]
|
|
@@ -94,13 +93,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
94
93
|
then corresponds to the initial number of omega points we train the PINN.
|
|
95
94
|
"""
|
|
96
95
|
|
|
97
|
-
tmin:
|
|
98
|
-
tmax:
|
|
99
|
-
ni: int = eqx.field(
|
|
100
|
-
domain_batch_size: int | None = eqx.field(
|
|
101
|
-
initial_batch_size: int | None = eqx.field(
|
|
102
|
-
border_batch_size: int | None = eqx.field(
|
|
96
|
+
tmin: float
|
|
97
|
+
tmax: float
|
|
98
|
+
ni: int = eqx.field(static=True)
|
|
99
|
+
domain_batch_size: int | None = eqx.field(static=True)
|
|
100
|
+
initial_batch_size: int | None = eqx.field(static=True)
|
|
101
|
+
border_batch_size: int | None = eqx.field(static=True)
|
|
103
102
|
|
|
103
|
+
# --- Below fields are not passed as arguments to __init__
|
|
104
104
|
curr_domain_idx: int = eqx.field(init=False)
|
|
105
105
|
curr_initial_idx: int = eqx.field(init=False)
|
|
106
106
|
curr_border_idx: int = eqx.field(init=False)
|
|
@@ -110,13 +110,32 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
110
110
|
)
|
|
111
111
|
initial: Float[Array, " ni dim"] | None = eqx.field(init=False)
|
|
112
112
|
|
|
113
|
-
def
|
|
113
|
+
def __init__(
|
|
114
|
+
self,
|
|
115
|
+
tmin: float,
|
|
116
|
+
tmax: float,
|
|
117
|
+
ni: int,
|
|
118
|
+
domain_batch_size: int | None = None,
|
|
119
|
+
initial_batch_size: int | None = None,
|
|
120
|
+
border_batch_size: int | None = None,
|
|
121
|
+
**kwargs, # kwargs for CubicMeshPDEStatio.__init__
|
|
122
|
+
):
|
|
114
123
|
"""
|
|
115
124
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
116
125
|
Module with eqx.tree_at!
|
|
117
126
|
"""
|
|
118
|
-
|
|
119
|
-
|
|
127
|
+
# sanity check
|
|
128
|
+
if ni is None:
|
|
129
|
+
raise ValueError("`ni` cannot be None.")
|
|
130
|
+
|
|
131
|
+
super().__init__(**kwargs)
|
|
132
|
+
self.tmin = tmin
|
|
133
|
+
self.tmax = tmax
|
|
134
|
+
self.ni = ni
|
|
135
|
+
|
|
136
|
+
self.domain_batch_size = domain_batch_size
|
|
137
|
+
self.initial_batch_size = initial_batch_size
|
|
138
|
+
self.border_batch_size = border_batch_size
|
|
120
139
|
|
|
121
140
|
if self.method == "grid":
|
|
122
141
|
# NOTE we must redo the sampling with the square root number of samples
|
|
@@ -144,7 +163,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
144
163
|
)
|
|
145
164
|
self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
|
|
146
165
|
|
|
147
|
-
# NOTE
|
|
166
|
+
# NOTE below re-do CubicMeshPDE.__init__() ? Maybe useless?
|
|
148
167
|
(
|
|
149
168
|
self.n_start,
|
|
150
169
|
self.p,
|
|
@@ -178,7 +197,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
178
197
|
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
179
198
|
)
|
|
180
199
|
# the check below concern omega_border_batch_size for dim > 1 in
|
|
181
|
-
# super.
|
|
200
|
+
# super.__init__. Here it concerns all dim values since our
|
|
182
201
|
# border_batch is the concatenation or cartesian product with times
|
|
183
202
|
if (
|
|
184
203
|
self.border_batch_size is not None
|
|
@@ -221,7 +240,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
221
240
|
self.border_batch_size = None
|
|
222
241
|
self.curr_border_idx = 0
|
|
223
242
|
|
|
224
|
-
if
|
|
243
|
+
if ni is not None:
|
|
225
244
|
if self.method == "grid":
|
|
226
245
|
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
227
246
|
if self.ni != perfect_sq:
|
|
@@ -235,17 +254,17 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
235
254
|
log2_n = jnp.log2(self.ni)
|
|
236
255
|
lower_pow = 2 ** jnp.floor(log2_n)
|
|
237
256
|
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
238
|
-
|
|
257
|
+
closest_power_of_two = (
|
|
239
258
|
lower_pow
|
|
240
259
|
if (self.ni - lower_pow) < (higher_pow - self.ni)
|
|
241
260
|
else higher_pow
|
|
242
261
|
)
|
|
243
|
-
if self.n !=
|
|
262
|
+
if self.n != closest_power_of_two:
|
|
244
263
|
warnings.warn(
|
|
245
264
|
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
246
|
-
f"Modfiying self.n from {self.ni} to {
|
|
265
|
+
f"Modfiying self.n from {self.ni} to {closest_power_of_two}.",
|
|
247
266
|
)
|
|
248
|
-
self.ni = int(
|
|
267
|
+
self.ni = int(closest_power_of_two)
|
|
249
268
|
self.key, self.initial = self.generate_omega_data(
|
|
250
269
|
self.key, data_size=self.ni
|
|
251
270
|
)
|
|
@@ -264,8 +283,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
264
283
|
self.omega_border = None
|
|
265
284
|
|
|
266
285
|
def generate_time_data(
|
|
267
|
-
self, key:
|
|
268
|
-
) -> tuple[
|
|
286
|
+
self, key: PRNGKeyArray, nt: int
|
|
287
|
+
) -> tuple[PRNGKeyArray, Float[Array, " nt 1"]]:
|
|
269
288
|
"""
|
|
270
289
|
Construct a complete set of `nt` time points according to the
|
|
271
290
|
specified `self.method`
|
|
@@ -278,12 +297,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
278
297
|
return key, self.sample_in_time_domain(subkey, nt)
|
|
279
298
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
280
299
|
|
|
281
|
-
def sample_in_time_domain(
|
|
300
|
+
def sample_in_time_domain(
|
|
301
|
+
self, key: PRNGKeyArray, nt: int
|
|
302
|
+
) -> Float[Array, " nt 1"]:
|
|
282
303
|
return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
|
|
283
304
|
|
|
284
305
|
def qmc_in_time_omega_domain(
|
|
285
|
-
self, key:
|
|
286
|
-
) -> tuple[
|
|
306
|
+
self, key: PRNGKeyArray, sample_size: int
|
|
307
|
+
) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]]:
|
|
287
308
|
"""
|
|
288
309
|
Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
|
|
289
310
|
We generate time and omega samples jointly
|
|
@@ -300,8 +321,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
300
321
|
return key, jnp.array(samples)
|
|
301
322
|
|
|
302
323
|
def qmc_in_time_omega_border_domain(
|
|
303
|
-
self, key:
|
|
304
|
-
) -> tuple[
|
|
324
|
+
self, key: PRNGKeyArray, sample_size: int | None = None
|
|
325
|
+
) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]] | None:
|
|
305
326
|
"""
|
|
306
327
|
For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
|
|
307
328
|
|
|
@@ -387,7 +408,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
387
408
|
|
|
388
409
|
def _get_domain_operands(
|
|
389
410
|
self,
|
|
390
|
-
) -> tuple[
|
|
411
|
+
) -> tuple[PRNGKeyArray, Float[Array, " n 1+dim"], int, int | None, Array | None]:
|
|
391
412
|
return (
|
|
392
413
|
self.key,
|
|
393
414
|
self.domain,
|
|
@@ -424,7 +445,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
424
445
|
# handled above
|
|
425
446
|
)
|
|
426
447
|
new = eqx.tree_at(
|
|
427
|
-
lambda m: (m.key, m.domain, m.curr_domain_idx),
|
|
448
|
+
lambda m: (m.key, m.domain, m.curr_domain_idx), # type: ignore
|
|
428
449
|
self,
|
|
429
450
|
new_attributes,
|
|
430
451
|
)
|
|
@@ -437,7 +458,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
437
458
|
def _get_border_operands(
|
|
438
459
|
self,
|
|
439
460
|
) -> tuple[
|
|
440
|
-
|
|
461
|
+
PRNGKeyArray,
|
|
441
462
|
Float[Array, " nb 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
|
|
442
463
|
int,
|
|
443
464
|
int | None,
|
|
@@ -483,7 +504,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
483
504
|
# handled above
|
|
484
505
|
)
|
|
485
506
|
new = eqx.tree_at(
|
|
486
|
-
lambda m: (m.key, m.border, m.curr_border_idx),
|
|
507
|
+
lambda m: (m.key, m.border, m.curr_border_idx), # type: ignore
|
|
487
508
|
self,
|
|
488
509
|
new_attributes,
|
|
489
510
|
)
|
|
@@ -500,7 +521,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
500
521
|
|
|
501
522
|
def _get_initial_operands(
|
|
502
523
|
self,
|
|
503
|
-
) -> tuple[
|
|
524
|
+
) -> tuple[PRNGKeyArray, Float[Array, " ni dim"] | None, int, int | None, None]:
|
|
504
525
|
return (
|
|
505
526
|
self.key,
|
|
506
527
|
self.initial,
|
|
@@ -529,7 +550,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
529
550
|
# handled above
|
|
530
551
|
)
|
|
531
552
|
new = eqx.tree_at(
|
|
532
|
-
lambda m: (m.key, m.initial, m.curr_initial_idx),
|
|
553
|
+
lambda m: (m.key, m.initial, m.curr_initial_idx), # type: ignore
|
|
533
554
|
self,
|
|
534
555
|
new_attributes,
|
|
535
556
|
)
|
|
@@ -11,7 +11,7 @@ import jax
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import jax.numpy as jnp
|
|
13
13
|
from scipy.stats import qmc
|
|
14
|
-
from jaxtyping import
|
|
14
|
+
from jaxtyping import PRNGKeyArray, Array, Float
|
|
15
15
|
from typing import Literal
|
|
16
16
|
from jinns.data._Batchs import PDEStatioBatch
|
|
17
17
|
from jinns.data._utils import _check_and_set_rar_parameters, _reset_or_increment
|
|
@@ -25,7 +25,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
25
25
|
|
|
26
26
|
Parameters
|
|
27
27
|
----------
|
|
28
|
-
key :
|
|
28
|
+
key : PRNGKeyArray
|
|
29
29
|
Jax random key to sample new time points and to shuffle batches
|
|
30
30
|
n : int
|
|
31
31
|
The number of total $\Omega$ points that will be divided in
|
|
@@ -80,32 +80,28 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
80
80
|
then corresponds to the initial number of points we train the PINN on.
|
|
81
81
|
"""
|
|
82
82
|
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
n: int = eqx.field(kw_only=True, static=True)
|
|
83
|
+
key: PRNGKeyArray
|
|
84
|
+
n: int = eqx.field(static=True)
|
|
86
85
|
nb: int | None = eqx.field(kw_only=True, static=True, default=None)
|
|
87
86
|
omega_batch_size: int | None = eqx.field(
|
|
88
|
-
kw_only=True,
|
|
89
87
|
static=True,
|
|
90
|
-
|
|
88
|
+
# can be None as
|
|
91
89
|
# CubicMeshPDENonStatio inherits but also if omega_batch_size=n
|
|
92
90
|
) # static cause used as a
|
|
93
91
|
# shape in jax.lax.dynamic_slice
|
|
94
92
|
omega_border_batch_size: int | None = eqx.field(
|
|
95
|
-
|
|
93
|
+
static=True,
|
|
96
94
|
) # static cause used as a
|
|
97
95
|
# shape in jax.lax.dynamic_slice
|
|
98
|
-
dim: int = eqx.field(
|
|
96
|
+
dim: int = eqx.field(static=True) # static cause used as a
|
|
99
97
|
# shape in jax.lax.dynamic_slice
|
|
100
|
-
min_pts: tuple[float, ...]
|
|
101
|
-
max_pts: tuple[float, ...]
|
|
102
|
-
method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(
|
|
103
|
-
|
|
104
|
-
)
|
|
105
|
-
rar_parameters: dict[str, int] = eqx.field(kw_only=True, default=None)
|
|
106
|
-
n_start: int = eqx.field(kw_only=True, default=None, static=True)
|
|
98
|
+
min_pts: tuple[float, ...]
|
|
99
|
+
max_pts: tuple[float, ...]
|
|
100
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = eqx.field(static=True)
|
|
101
|
+
rar_parameters: None | dict[str, int]
|
|
102
|
+
n_start: int = eqx.field(static=True)
|
|
107
103
|
|
|
108
|
-
#
|
|
104
|
+
# --- Below fields are not passed as arguments to __init__
|
|
109
105
|
p: Float[Array, " n"] | None = eqx.field(init=False)
|
|
110
106
|
rar_iter_from_last_sampling: int | None = eqx.field(init=False)
|
|
111
107
|
rar_iter_nb: int | None = eqx.field(init=False)
|
|
@@ -116,7 +112,32 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
116
112
|
eqx.field(init=False)
|
|
117
113
|
)
|
|
118
114
|
|
|
119
|
-
def
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
*,
|
|
118
|
+
key: PRNGKeyArray,
|
|
119
|
+
n: int,
|
|
120
|
+
nb: int | None = None,
|
|
121
|
+
omega_batch_size: int | None = None,
|
|
122
|
+
omega_border_batch_size: int | None = None,
|
|
123
|
+
dim: int,
|
|
124
|
+
min_pts: tuple[float, ...],
|
|
125
|
+
max_pts: tuple[float, ...],
|
|
126
|
+
method: Literal["grid", "uniform", "sobol", "halton"] = "uniform",
|
|
127
|
+
rar_parameters: dict[str, int] | None = None,
|
|
128
|
+
n_start: int | None = None,
|
|
129
|
+
):
|
|
130
|
+
self.key = key
|
|
131
|
+
self.n = n
|
|
132
|
+
self.nb = nb
|
|
133
|
+
self.omega_batch_size = omega_batch_size
|
|
134
|
+
self.omega_border_batch_size = omega_border_batch_size
|
|
135
|
+
self.dim = dim
|
|
136
|
+
self.min_pts = min_pts
|
|
137
|
+
self.max_pts = max_pts
|
|
138
|
+
self.method = method
|
|
139
|
+
self.rar_parameters = rar_parameters
|
|
140
|
+
|
|
120
141
|
assert self.dim == len(self.min_pts) and isinstance(self.min_pts, tuple)
|
|
121
142
|
assert self.dim == len(self.max_pts) and isinstance(self.max_pts, tuple)
|
|
122
143
|
|
|
@@ -125,7 +146,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
125
146
|
self.p,
|
|
126
147
|
self.rar_iter_from_last_sampling,
|
|
127
148
|
self.rar_iter_nb,
|
|
128
|
-
) = _check_and_set_rar_parameters(self.rar_parameters, self.n,
|
|
149
|
+
) = _check_and_set_rar_parameters(self.rar_parameters, self.n, n_start)
|
|
129
150
|
|
|
130
151
|
if self.method == "grid" and self.dim == 2:
|
|
131
152
|
perfect_sq = int(jnp.round(jnp.sqrt(self.n)) ** 2)
|
|
@@ -195,13 +216,13 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
195
216
|
self.key, self.omega_border = self.generate_omega_border_data(self.key)
|
|
196
217
|
|
|
197
218
|
def sample_in_omega_domain(
|
|
198
|
-
self, keys:
|
|
219
|
+
self, keys: list[PRNGKeyArray], sample_size: int
|
|
199
220
|
) -> Float[Array, " n dim"]:
|
|
200
221
|
if self.method == "uniform":
|
|
201
222
|
if self.dim == 1:
|
|
202
223
|
xmin, xmax = self.min_pts[0], self.max_pts[0]
|
|
203
224
|
return jax.random.uniform(
|
|
204
|
-
keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
225
|
+
*keys, shape=(sample_size, 1), minval=xmin, maxval=xmax
|
|
205
226
|
)
|
|
206
227
|
|
|
207
228
|
return jnp.concatenate(
|
|
@@ -217,10 +238,10 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
217
238
|
axis=-1,
|
|
218
239
|
)
|
|
219
240
|
else:
|
|
220
|
-
return self._qmc_in_omega_domain(keys, sample_size)
|
|
241
|
+
return self._qmc_in_omega_domain(keys[0], sample_size)
|
|
221
242
|
|
|
222
243
|
def _qmc_in_omega_domain(
|
|
223
|
-
self, subkey:
|
|
244
|
+
self, subkey: PRNGKeyArray, sample_size: int
|
|
224
245
|
) -> Float[Array, "n dim"]:
|
|
225
246
|
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
226
247
|
if self.dim == 1:
|
|
@@ -241,7 +262,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
241
262
|
return jnp.array(samples)
|
|
242
263
|
|
|
243
264
|
def sample_in_omega_border_domain(
|
|
244
|
-
self, keys:
|
|
265
|
+
self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
|
|
245
266
|
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
246
267
|
sample_size = self.nb if sample_size is None else sample_size
|
|
247
268
|
if sample_size is None:
|
|
@@ -251,6 +272,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
251
272
|
xmax = self.max_pts[0]
|
|
252
273
|
return jnp.array([xmin, xmax]).astype(float)
|
|
253
274
|
if self.dim == 2:
|
|
275
|
+
assert keys is not None
|
|
254
276
|
# currently hard-coded the 4 edges for d==2
|
|
255
277
|
# TODO : find a general & efficient way to sample from the border
|
|
256
278
|
# (facets) of the hypercube in general dim.
|
|
@@ -306,7 +328,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
306
328
|
)
|
|
307
329
|
|
|
308
330
|
def qmc_in_omega_border_domain(
|
|
309
|
-
self, keys:
|
|
331
|
+
self, keys: list[PRNGKeyArray] | None, sample_size: int | None = None
|
|
310
332
|
) -> Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None:
|
|
311
333
|
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
312
334
|
sample_size = self.nb if sample_size is None else sample_size
|
|
@@ -317,6 +339,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
317
339
|
xmax = self.max_pts[0]
|
|
318
340
|
return jnp.array([xmin, xmax]).astype(float)
|
|
319
341
|
if self.dim == 2:
|
|
342
|
+
assert keys is not None
|
|
320
343
|
# currently hard-coded the 4 edges for d==2
|
|
321
344
|
# TODO : find a general & efficient way to sample from the border
|
|
322
345
|
# (facets) of the hypercube in general dim.
|
|
@@ -362,9 +385,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
362
385
|
)
|
|
363
386
|
|
|
364
387
|
def generate_omega_data(
|
|
365
|
-
self, key:
|
|
388
|
+
self, key: PRNGKeyArray, data_size: int | None = None
|
|
366
389
|
) -> tuple[
|
|
367
|
-
|
|
390
|
+
PRNGKeyArray,
|
|
368
391
|
Float[Array, " n dim"],
|
|
369
392
|
]:
|
|
370
393
|
r"""
|
|
@@ -393,18 +416,19 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
393
416
|
omega = jnp.concatenate(xyz_, axis=-1)
|
|
394
417
|
elif self.method in ["uniform", "sobol", "halton"]:
|
|
395
418
|
if self.dim == 1 or self.method in ["sobol", "halton"]:
|
|
396
|
-
key,
|
|
419
|
+
key, subkey = jax.random.split(key, 2)
|
|
420
|
+
omega = self.sample_in_omega_domain([subkey], sample_size=data_size)
|
|
397
421
|
else:
|
|
398
422
|
key, *subkeys = jax.random.split(key, self.dim + 1)
|
|
399
|
-
|
|
423
|
+
omega = self.sample_in_omega_domain(subkeys, sample_size=data_size)
|
|
400
424
|
else:
|
|
401
425
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
402
426
|
return key, omega
|
|
403
427
|
|
|
404
428
|
def generate_omega_border_data(
|
|
405
|
-
self, key:
|
|
429
|
+
self, key: PRNGKeyArray, data_size: int | None = None
|
|
406
430
|
) -> tuple[
|
|
407
|
-
|
|
431
|
+
PRNGKeyArray,
|
|
408
432
|
Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
|
|
409
433
|
]:
|
|
410
434
|
r"""
|
|
@@ -433,7 +457,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
433
457
|
|
|
434
458
|
def _get_omega_operands(
|
|
435
459
|
self,
|
|
436
|
-
) -> tuple[
|
|
460
|
+
) -> tuple[
|
|
461
|
+
PRNGKeyArray, Float[Array, " n dim"], int, int | None, Float[Array, " n"] | None
|
|
462
|
+
]:
|
|
437
463
|
return (
|
|
438
464
|
self.key,
|
|
439
465
|
self.omega,
|
|
@@ -475,7 +501,9 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
475
501
|
# handled above
|
|
476
502
|
)
|
|
477
503
|
new = eqx.tree_at(
|
|
478
|
-
lambda m: (m.key, m.omega, m.curr_omega_idx),
|
|
504
|
+
lambda m: (m.key, m.omega, m.curr_omega_idx), # type: ignore
|
|
505
|
+
self,
|
|
506
|
+
new_attributes,
|
|
479
507
|
)
|
|
480
508
|
|
|
481
509
|
return new, jax.lax.dynamic_slice(
|
|
@@ -487,7 +515,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
487
515
|
def _get_omega_border_operands(
|
|
488
516
|
self,
|
|
489
517
|
) -> tuple[
|
|
490
|
-
|
|
518
|
+
PRNGKeyArray,
|
|
491
519
|
Float[Array, " 1 2"] | Float[Array, " (nb//4) 2 4"] | None,
|
|
492
520
|
int,
|
|
493
521
|
int | None,
|
|
@@ -551,7 +579,7 @@ class CubicMeshPDEStatio(AbstractDataGenerator):
|
|
|
551
579
|
# handled above
|
|
552
580
|
)
|
|
553
581
|
new = eqx.tree_at(
|
|
554
|
-
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx),
|
|
582
|
+
lambda m: (m.key, m.omega_border, m.curr_omega_border_idx), # type: ignore
|
|
555
583
|
self,
|
|
556
584
|
new_attributes,
|
|
557
585
|
)
|