jinns 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jinns/__init__.py +7 -7
- jinns/data/_AbstractDataGenerator.py +1 -1
- jinns/data/_Batchs.py +47 -13
- jinns/data/_CubicMeshPDENonStatio.py +203 -54
- jinns/data/_CubicMeshPDEStatio.py +190 -54
- jinns/data/_DataGeneratorODE.py +48 -22
- jinns/data/_DataGeneratorObservations.py +75 -32
- jinns/data/_DataGeneratorParameter.py +152 -101
- jinns/data/__init__.py +2 -1
- jinns/data/_utils.py +22 -10
- jinns/loss/_DynamicLoss.py +21 -20
- jinns/loss/_DynamicLossAbstract.py +51 -36
- jinns/loss/_LossODE.py +210 -191
- jinns/loss/_LossPDE.py +441 -368
- jinns/loss/_abstract_loss.py +60 -25
- jinns/loss/_loss_components.py +4 -25
- jinns/loss/_loss_utils.py +23 -0
- jinns/loss/_loss_weight_updates.py +6 -7
- jinns/loss/_loss_weights.py +34 -35
- jinns/nn/_abstract_pinn.py +0 -2
- jinns/nn/_hyperpinn.py +34 -23
- jinns/nn/_mlp.py +5 -4
- jinns/nn/_pinn.py +1 -16
- jinns/nn/_ppinn.py +5 -16
- jinns/nn/_save_load.py +11 -4
- jinns/nn/_spinn.py +1 -16
- jinns/nn/_spinn_mlp.py +5 -5
- jinns/nn/_utils.py +33 -38
- jinns/parameters/__init__.py +3 -1
- jinns/parameters/_derivative_keys.py +99 -41
- jinns/parameters/_params.py +58 -25
- jinns/solver/_solve.py +14 -8
- jinns/utils/_DictToModuleMeta.py +66 -0
- jinns/utils/_ItemizableModule.py +19 -0
- jinns/utils/__init__.py +2 -1
- jinns/utils/_types.py +25 -15
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/METADATA +2 -2
- jinns-1.6.0.dist-info/RECORD +57 -0
- jinns-1.5.0.dist-info/RECORD +0 -55
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/WHEEL +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/AUTHORS +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/licenses/LICENSE +0 -0
- {jinns-1.5.0.dist-info → jinns-1.6.0.dist-info}/top_level.txt +0 -0
jinns/__init__.py
CHANGED
|
@@ -1,10 +1,3 @@
|
|
|
1
|
-
# import jinns.data
|
|
2
|
-
# import jinns.loss
|
|
3
|
-
# import jinns.solver
|
|
4
|
-
# import jinns.utils
|
|
5
|
-
# import jinns.experimental
|
|
6
|
-
# import jinns.parameters
|
|
7
|
-
# import jinns.plot
|
|
8
1
|
from jinns import data as data
|
|
9
2
|
from jinns import loss as loss
|
|
10
3
|
from jinns import solver as solver
|
|
@@ -16,3 +9,10 @@ from jinns import nn as nn
|
|
|
16
9
|
from jinns.solver._solve import solve
|
|
17
10
|
|
|
18
11
|
__all__ = ["nn", "solve"]
|
|
12
|
+
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
warnings.filterwarnings(
|
|
16
|
+
action="ignore",
|
|
17
|
+
message=r"Using `field\(init=False\)`",
|
|
18
|
+
)
|
jinns/data/_Batchs.py
CHANGED
|
@@ -18,25 +18,59 @@ class ObsBatchDict(TypedDict):
|
|
|
18
18
|
|
|
19
19
|
pinn_in: Float[Array, " obs_batch_size input_dim"]
|
|
20
20
|
val: Float[Array, " obs_batch_size output_dim"]
|
|
21
|
-
eq_params:
|
|
21
|
+
eq_params: (
|
|
22
|
+
eqx.Module | None
|
|
23
|
+
) # None cause sometime user don't provide observed params
|
|
22
24
|
|
|
23
25
|
|
|
24
26
|
class ODEBatch(eqx.Module):
|
|
25
27
|
temporal_batch: Float[Array, " batch_size"]
|
|
26
|
-
param_batch_dict:
|
|
27
|
-
obs_batch_dict: ObsBatchDict = eqx.field(default=None)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class PDENonStatioBatch(eqx.Module):
|
|
31
|
-
domain_batch: Float[Array, " batch_size 1+dimension"]
|
|
32
|
-
border_batch: Float[Array, " batch_size dimension n_facets"] | None
|
|
33
|
-
initial_batch: Float[Array, " batch_size dimension"] | None
|
|
34
|
-
param_batch_dict: dict[str, Array] = eqx.field(default=None)
|
|
35
|
-
obs_batch_dict: ObsBatchDict = eqx.field(default=None)
|
|
28
|
+
param_batch_dict: eqx.Module | None = eqx.field(default=None)
|
|
29
|
+
obs_batch_dict: ObsBatchDict | None = eqx.field(default=None)
|
|
36
30
|
|
|
37
31
|
|
|
38
32
|
class PDEStatioBatch(eqx.Module):
|
|
39
33
|
domain_batch: Float[Array, " batch_size dimension"]
|
|
40
34
|
border_batch: Float[Array, " batch_size dimension n_facets"] | None
|
|
41
|
-
param_batch_dict:
|
|
42
|
-
obs_batch_dict: ObsBatchDict
|
|
35
|
+
param_batch_dict: eqx.Module | None
|
|
36
|
+
obs_batch_dict: ObsBatchDict | None
|
|
37
|
+
|
|
38
|
+
# rewrite __init__ to be able to use inheritance for the NonStatio case
|
|
39
|
+
# below. That way PDENonStatioBatch is a subtype of PDEStatioBatch which
|
|
40
|
+
# 1) makes more sense and 2) CubicMeshPDENonStatio.get_batch passes pyright.
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
*,
|
|
44
|
+
domain_batch: Float[Array, " batch_size dimension"],
|
|
45
|
+
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
46
|
+
param_batch_dict: eqx.Module | None = None,
|
|
47
|
+
obs_batch_dict: ObsBatchDict | None = None,
|
|
48
|
+
):
|
|
49
|
+
# TODO: document this ?
|
|
50
|
+
self.domain_batch = domain_batch
|
|
51
|
+
self.border_batch = border_batch
|
|
52
|
+
self.param_batch_dict = param_batch_dict
|
|
53
|
+
self.obs_batch_dict = obs_batch_dict
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PDENonStatioBatch(PDEStatioBatch):
|
|
57
|
+
# TODO: document this ?
|
|
58
|
+
domain_batch: Float[Array, " batch_size 1+dimension"] # Override type
|
|
59
|
+
initial_batch: (
|
|
60
|
+
Float[Array, " batch_size dimension"] | None
|
|
61
|
+
) # why can it be None ? Examples?
|
|
62
|
+
|
|
63
|
+
def __init__(
|
|
64
|
+
self,
|
|
65
|
+
*,
|
|
66
|
+
domain_batch: Float[Array, " batch_size 1+dimension"],
|
|
67
|
+
border_batch: Float[Array, " batch_size dimension n_facets"] | None,
|
|
68
|
+
initial_batch: Float[Array, " batch_size dimension"] | None,
|
|
69
|
+
param_batch_dict: eqx.Module | None = None,
|
|
70
|
+
obs_batch_dict: ObsBatchDict | None = None,
|
|
71
|
+
):
|
|
72
|
+
self.domain_batch = domain_batch
|
|
73
|
+
self.border_batch = border_batch
|
|
74
|
+
self.initial_batch = initial_batch
|
|
75
|
+
self.param_batch_dict = param_batch_dict
|
|
76
|
+
self.obs_batch_dict = obs_batch_dict
|
|
@@ -7,9 +7,11 @@ from __future__ import (
|
|
|
7
7
|
) # https://docs.python.org/3/library/typing.html#constant
|
|
8
8
|
import warnings
|
|
9
9
|
import equinox as eqx
|
|
10
|
+
import numpy as np
|
|
10
11
|
import jax
|
|
11
12
|
import jax.numpy as jnp
|
|
12
|
-
from
|
|
13
|
+
from scipy.stats import qmc
|
|
14
|
+
from jaxtyping import PRNGKeyArray, Array, Float
|
|
13
15
|
from jinns.data._Batchs import PDENonStatioBatch
|
|
14
16
|
from jinns.data._utils import (
|
|
15
17
|
make_cartesian_product,
|
|
@@ -27,7 +29,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
27
29
|
|
|
28
30
|
Parameters
|
|
29
31
|
----------
|
|
30
|
-
key :
|
|
32
|
+
key : PRNGKeyArray
|
|
31
33
|
Jax random key to sample new time points and to shuffle batches
|
|
32
34
|
n : int
|
|
33
35
|
The number of total $I\times \Omega$ points that will be divided in
|
|
@@ -48,9 +50,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
48
50
|
among the `nb` points. If None, `domain_batch_size` no
|
|
49
51
|
mini-batches are used.
|
|
50
52
|
initial_batch_size : int | None, default=None
|
|
51
|
-
The
|
|
52
|
-
|
|
53
|
-
mini-batches are used.
|
|
53
|
+
The number of randomly selected points among the `ni` initial spatial
|
|
54
|
+
points used for initial condition. If None, no mini-batches are used.
|
|
54
55
|
dim : int
|
|
55
56
|
An integer. Dimension of $\Omega$ domain.
|
|
56
57
|
min_pts : tuple[tuple[Float, Float], ...]
|
|
@@ -65,11 +66,13 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
65
66
|
The minimum value of the time domain to consider
|
|
66
67
|
tmax : float
|
|
67
68
|
The maximum value of the time domain to consider
|
|
68
|
-
method :
|
|
69
|
+
method : Literal["uniform", "grid", "sobol", "halton"], default="uniform"
|
|
69
70
|
Either `grid` or `uniform`, default is `uniform`.
|
|
70
71
|
The method that generates the `nt` time points. `grid` means
|
|
71
72
|
regularly spaced points over the domain. `uniform` means uniformly
|
|
72
|
-
sampled points over the domain
|
|
73
|
+
sampled points over the domain.
|
|
74
|
+
**Note** that Sobol and Halton approaches use scipy modules and will not
|
|
75
|
+
be JIT compatible.
|
|
73
76
|
rar_parameters : Dict[str, int], default=None
|
|
74
77
|
Defaults to None: do not use Residual Adaptative Resampling.
|
|
75
78
|
Otherwise a dictionary with keys
|
|
@@ -90,13 +93,14 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
90
93
|
then corresponds to the initial number of omega points we train the PINN.
|
|
91
94
|
"""
|
|
92
95
|
|
|
93
|
-
tmin:
|
|
94
|
-
tmax:
|
|
95
|
-
ni: int = eqx.field(
|
|
96
|
-
domain_batch_size: int | None = eqx.field(
|
|
97
|
-
initial_batch_size: int | None = eqx.field(
|
|
98
|
-
border_batch_size: int | None = eqx.field(
|
|
96
|
+
tmin: float
|
|
97
|
+
tmax: float
|
|
98
|
+
ni: int = eqx.field(static=True)
|
|
99
|
+
domain_batch_size: int | None = eqx.field(static=True)
|
|
100
|
+
initial_batch_size: int | None = eqx.field(static=True)
|
|
101
|
+
border_batch_size: int | None = eqx.field(static=True)
|
|
99
102
|
|
|
103
|
+
# --- Below fields are not passed as arguments to __init__
|
|
100
104
|
curr_domain_idx: int = eqx.field(init=False)
|
|
101
105
|
curr_initial_idx: int = eqx.field(init=False)
|
|
102
106
|
curr_border_idx: int = eqx.field(init=False)
|
|
@@ -106,13 +110,32 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
106
110
|
)
|
|
107
111
|
initial: Float[Array, " ni dim"] | None = eqx.field(init=False)
|
|
108
112
|
|
|
109
|
-
def
|
|
113
|
+
def __init__(
|
|
114
|
+
self,
|
|
115
|
+
tmin: float,
|
|
116
|
+
tmax: float,
|
|
117
|
+
ni: int,
|
|
118
|
+
domain_batch_size: int | None = None,
|
|
119
|
+
initial_batch_size: int | None = None,
|
|
120
|
+
border_batch_size: int | None = None,
|
|
121
|
+
**kwargs, # kwargs for CubicMeshPDEStatio.__init__
|
|
122
|
+
):
|
|
110
123
|
"""
|
|
111
124
|
Note that neither __init__ or __post_init__ are called when udating a
|
|
112
125
|
Module with eqx.tree_at!
|
|
113
126
|
"""
|
|
114
|
-
|
|
115
|
-
|
|
127
|
+
# sanity check
|
|
128
|
+
if ni is None:
|
|
129
|
+
raise ValueError("`ni` cannot be None.")
|
|
130
|
+
|
|
131
|
+
super().__init__(**kwargs)
|
|
132
|
+
self.tmin = tmin
|
|
133
|
+
self.tmax = tmax
|
|
134
|
+
self.ni = ni
|
|
135
|
+
|
|
136
|
+
self.domain_batch_size = domain_batch_size
|
|
137
|
+
self.initial_batch_size = initial_batch_size
|
|
138
|
+
self.border_batch_size = border_batch_size
|
|
116
139
|
|
|
117
140
|
if self.method == "grid":
|
|
118
141
|
# NOTE we must redo the sampling with the square root number of samples
|
|
@@ -140,7 +163,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
140
163
|
)
|
|
141
164
|
self.domain = make_cartesian_product(half_domain_times, half_domain_omega)
|
|
142
165
|
|
|
143
|
-
# NOTE
|
|
166
|
+
# NOTE below re-do CubicMeshPDE.__init__() ? Maybe useless?
|
|
144
167
|
(
|
|
145
168
|
self.n_start,
|
|
146
169
|
self.p,
|
|
@@ -150,9 +173,11 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
150
173
|
elif self.method == "uniform":
|
|
151
174
|
self.key, domain_times = self.generate_time_data(self.key, self.n)
|
|
152
175
|
self.domain = jnp.concatenate([domain_times, self.omega], axis=1)
|
|
176
|
+
elif self.method in ["sobol", "halton"]:
|
|
177
|
+
self.key, self.domain = self.qmc_in_time_omega_domain(self.key, self.n)
|
|
153
178
|
else:
|
|
154
179
|
raise ValueError(
|
|
155
|
-
f'Bad value for method. Got {self.method}, expected "grid" or "uniform"'
|
|
180
|
+
f'Bad value for method. Got {self.method}, expected "grid" or "uniform" or "sobol" or "halton"'
|
|
156
181
|
)
|
|
157
182
|
|
|
158
183
|
if self.domain_batch_size is None:
|
|
@@ -172,7 +197,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
172
197
|
" a multiple of 2xd (the # of faces of a d-dimensional cube)"
|
|
173
198
|
)
|
|
174
199
|
# the check below concern omega_border_batch_size for dim > 1 in
|
|
175
|
-
# super.
|
|
200
|
+
# super.__init__. Here it concerns all dim values since our
|
|
176
201
|
# border_batch is the concatenation or cartesian product with times
|
|
177
202
|
if (
|
|
178
203
|
self.border_batch_size is not None
|
|
@@ -182,21 +207,28 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
182
207
|
"number of points per facets (nb//2*self.dim)"
|
|
183
208
|
" cannot be lower than border batch size"
|
|
184
209
|
)
|
|
185
|
-
self.
|
|
186
|
-
self.key,
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
boundary_times
|
|
191
|
-
|
|
192
|
-
if self.dim == 1:
|
|
193
|
-
self.border = make_cartesian_product(
|
|
194
|
-
boundary_times, self.omega_border[None, None]
|
|
210
|
+
if self.method in ["grid", "uniform"]:
|
|
211
|
+
self.key, boundary_times = self.generate_time_data(
|
|
212
|
+
self.key, self.nb // (2 * self.dim)
|
|
213
|
+
)
|
|
214
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
215
|
+
boundary_times = jnp.repeat(
|
|
216
|
+
boundary_times, self.omega_border.shape[-1], axis=2
|
|
195
217
|
)
|
|
218
|
+
if self.dim == 1:
|
|
219
|
+
self.border = make_cartesian_product(
|
|
220
|
+
boundary_times, self.omega_border[None, None]
|
|
221
|
+
)
|
|
222
|
+
else:
|
|
223
|
+
self.border = jnp.concatenate(
|
|
224
|
+
[boundary_times, self.omega_border], axis=1
|
|
225
|
+
)
|
|
196
226
|
else:
|
|
197
|
-
self.border =
|
|
198
|
-
|
|
227
|
+
self.key, self.border = self.qmc_in_time_omega_border_domain(
|
|
228
|
+
self.key,
|
|
229
|
+
self.nb, # type: ignore (see inside the fun)
|
|
199
230
|
)
|
|
231
|
+
|
|
200
232
|
if self.border_batch_size is None:
|
|
201
233
|
self.curr_border_idx = 0
|
|
202
234
|
else:
|
|
@@ -208,15 +240,31 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
208
240
|
self.border_batch_size = None
|
|
209
241
|
self.curr_border_idx = 0
|
|
210
242
|
|
|
211
|
-
if
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
243
|
+
if ni is not None:
|
|
244
|
+
if self.method == "grid":
|
|
245
|
+
perfect_sq = int(jnp.round(jnp.sqrt(self.ni)) ** 2)
|
|
246
|
+
if self.ni != perfect_sq:
|
|
247
|
+
warnings.warn(
|
|
248
|
+
"Grid sampling is requested in dimension 2 with a non"
|
|
249
|
+
f" perfect square dataset size (self.ni = {self.ni})."
|
|
250
|
+
f" Modifying self.ni to self.ni = {perfect_sq}."
|
|
251
|
+
)
|
|
252
|
+
self.ni = perfect_sq
|
|
253
|
+
if self.method in ["sobol", "halton"]:
|
|
254
|
+
log2_n = jnp.log2(self.ni)
|
|
255
|
+
lower_pow = 2 ** jnp.floor(log2_n)
|
|
256
|
+
higher_pow = 2 ** jnp.ceil(log2_n)
|
|
257
|
+
closest_power_of_two = (
|
|
258
|
+
lower_pow
|
|
259
|
+
if (self.ni - lower_pow) < (higher_pow - self.ni)
|
|
260
|
+
else higher_pow
|
|
218
261
|
)
|
|
219
|
-
|
|
262
|
+
if self.n != closest_power_of_two:
|
|
263
|
+
warnings.warn(
|
|
264
|
+
f"QuasiMonteCarlo sampling with {self.method} requires sample size to be a power fo 2."
|
|
265
|
+
f"Modfiying self.n from {self.ni} to {closest_power_of_two}.",
|
|
266
|
+
)
|
|
267
|
+
self.ni = int(closest_power_of_two)
|
|
220
268
|
self.key, self.initial = self.generate_omega_data(
|
|
221
269
|
self.key, data_size=self.ni
|
|
222
270
|
)
|
|
@@ -235,8 +283,8 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
235
283
|
self.omega_border = None
|
|
236
284
|
|
|
237
285
|
def generate_time_data(
|
|
238
|
-
self, key:
|
|
239
|
-
) -> tuple[
|
|
286
|
+
self, key: PRNGKeyArray, nt: int
|
|
287
|
+
) -> tuple[PRNGKeyArray, Float[Array, " nt 1"]]:
|
|
240
288
|
"""
|
|
241
289
|
Construct a complete set of `nt` time points according to the
|
|
242
290
|
specified `self.method`
|
|
@@ -245,21 +293,122 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
245
293
|
if self.method == "grid":
|
|
246
294
|
partial_times = (self.tmax - self.tmin) / nt
|
|
247
295
|
return key, jnp.arange(self.tmin, self.tmax, partial_times)[:, None]
|
|
248
|
-
|
|
296
|
+
elif self.method in ["uniform", "sobol", "halton"]:
|
|
249
297
|
return key, self.sample_in_time_domain(subkey, nt)
|
|
250
298
|
raise ValueError("Method " + self.method + " is not implemented.")
|
|
251
299
|
|
|
252
|
-
def sample_in_time_domain(
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
300
|
+
def sample_in_time_domain(
|
|
301
|
+
self, key: PRNGKeyArray, nt: int
|
|
302
|
+
) -> Float[Array, " nt 1"]:
|
|
303
|
+
return jax.random.uniform(key, (nt, 1), minval=self.tmin, maxval=self.tmax)
|
|
304
|
+
|
|
305
|
+
def qmc_in_time_omega_domain(
|
|
306
|
+
self, key: PRNGKeyArray, sample_size: int
|
|
307
|
+
) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]]:
|
|
308
|
+
"""
|
|
309
|
+
Because in Quasi-Monte Carlo sampling we cannot concatenate two vectors generated independently
|
|
310
|
+
We generate time and omega samples jointly
|
|
311
|
+
"""
|
|
312
|
+
key, subkey = jax.random.split(key, 2)
|
|
313
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
314
|
+
sampler = qmc_generator(
|
|
315
|
+
d=self.dim + 1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
316
|
+
)
|
|
317
|
+
samples = sampler.random(n=sample_size)
|
|
318
|
+
samples[:, 1:] = qmc.scale(
|
|
319
|
+
samples[:, 1:], l_bounds=self.min_pts, u_bounds=self.max_pts
|
|
320
|
+
) # We scale omega domain to be in (min_pts, max_pts)
|
|
321
|
+
return key, jnp.array(samples)
|
|
322
|
+
|
|
323
|
+
def qmc_in_time_omega_border_domain(
|
|
324
|
+
self, key: PRNGKeyArray, sample_size: int | None = None
|
|
325
|
+
) -> tuple[PRNGKeyArray, Float[Array, "n 1+dim"]] | None:
|
|
326
|
+
"""
|
|
327
|
+
For each facet of the border we generate Quasi-MonteCarlo sequences jointy with time.
|
|
328
|
+
|
|
329
|
+
We need to do some type ignore in this function because we have lost
|
|
330
|
+
the type narrowing from post_init, type checkers only narrow at function level and because we cannot narrow a class attribute.
|
|
331
|
+
"""
|
|
332
|
+
qmc_generator = qmc.Sobol if self.method == "sobol" else qmc.Halton
|
|
333
|
+
sample_size = self.nb if sample_size is None else sample_size
|
|
334
|
+
if sample_size is None:
|
|
335
|
+
return None
|
|
336
|
+
if self.dim == 1:
|
|
337
|
+
key, subkey = jax.random.split(key, 2)
|
|
338
|
+
qmc_seq = qmc_generator(
|
|
339
|
+
d=1, scramble=True, rng=np.random.default_rng(np.uint32(subkey))
|
|
340
|
+
)
|
|
341
|
+
boundary_times = jnp.array(
|
|
342
|
+
qmc_seq.random(self.nb // (2 * self.dim)) # type: ignore
|
|
343
|
+
)
|
|
344
|
+
boundary_times = boundary_times.reshape(-1, 1, 1)
|
|
345
|
+
boundary_times = jnp.repeat(
|
|
346
|
+
boundary_times,
|
|
347
|
+
self.omega_border.shape[-1], # type: ignore
|
|
348
|
+
axis=2,
|
|
349
|
+
)
|
|
350
|
+
return key, make_cartesian_product(
|
|
351
|
+
boundary_times,
|
|
352
|
+
self.omega_border[None, None], # type: ignore
|
|
353
|
+
)
|
|
354
|
+
if self.dim == 2:
|
|
355
|
+
# currently hard-coded the 4 edges for d==2
|
|
356
|
+
# TODO : find a general & efficient way to sample from the border
|
|
357
|
+
# (facets) of the hypercube in general dim.
|
|
358
|
+
key, *subkeys = jax.random.split(key, 5)
|
|
359
|
+
facet_n = sample_size // (2 * self.dim)
|
|
360
|
+
|
|
361
|
+
def generate_qmc_sample(key, min_val, max_val):
|
|
362
|
+
qmc_seq = qmc_generator(
|
|
363
|
+
d=2,
|
|
364
|
+
scramble=True,
|
|
365
|
+
rng=np.random.default_rng(np.uint32(key)),
|
|
366
|
+
)
|
|
367
|
+
u = qmc_seq.random(n=facet_n)
|
|
368
|
+
u[:, 1:2] = qmc.scale(u[:, 1:2], l_bounds=min_val, u_bounds=max_val)
|
|
369
|
+
return jnp.array(u)
|
|
370
|
+
|
|
371
|
+
xmin_sample = generate_qmc_sample(
|
|
372
|
+
subkeys[0], self.min_pts[1], self.max_pts[1]
|
|
373
|
+
) # [t,x,y]
|
|
374
|
+
xmin = jnp.hstack(
|
|
375
|
+
[
|
|
376
|
+
xmin_sample[:, 0:1],
|
|
377
|
+
self.min_pts[0] * jnp.ones((facet_n, 1)),
|
|
378
|
+
xmin_sample[:, 1:2],
|
|
379
|
+
]
|
|
380
|
+
)
|
|
381
|
+
xmax_sample = generate_qmc_sample(
|
|
382
|
+
subkeys[1], self.min_pts[1], self.max_pts[1]
|
|
383
|
+
)
|
|
384
|
+
xmax = jnp.hstack(
|
|
385
|
+
[
|
|
386
|
+
xmax_sample[:, 0:1],
|
|
387
|
+
self.max_pts[0] * jnp.ones((facet_n, 1)),
|
|
388
|
+
xmax_sample[:, 1:2],
|
|
389
|
+
]
|
|
390
|
+
)
|
|
391
|
+
ymin = jnp.hstack(
|
|
392
|
+
[
|
|
393
|
+
generate_qmc_sample(subkeys[2], self.min_pts[0], self.max_pts[0]),
|
|
394
|
+
self.min_pts[1] * jnp.ones((facet_n, 1)),
|
|
395
|
+
]
|
|
396
|
+
)
|
|
397
|
+
ymax = jnp.hstack(
|
|
398
|
+
[
|
|
399
|
+
generate_qmc_sample(subkeys[3], self.min_pts[0], self.max_pts[0]),
|
|
400
|
+
self.max_pts[1] * jnp.ones((facet_n, 1)),
|
|
401
|
+
]
|
|
402
|
+
)
|
|
403
|
+
return key, jnp.stack([xmin, xmax, ymin, ymax], axis=-1)
|
|
404
|
+
raise NotImplementedError(
|
|
405
|
+
"Generation of the border of a cube in dimension > 2 is not "
|
|
406
|
+
+ f"implemented yet. You are asking for generation in dimension d={self.dim}."
|
|
258
407
|
)
|
|
259
408
|
|
|
260
409
|
def _get_domain_operands(
|
|
261
410
|
self,
|
|
262
|
-
) -> tuple[
|
|
411
|
+
) -> tuple[PRNGKeyArray, Float[Array, " n 1+dim"], int, int | None, Array | None]:
|
|
263
412
|
return (
|
|
264
413
|
self.key,
|
|
265
414
|
self.domain,
|
|
@@ -296,7 +445,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
296
445
|
# handled above
|
|
297
446
|
)
|
|
298
447
|
new = eqx.tree_at(
|
|
299
|
-
lambda m: (m.key, m.domain, m.curr_domain_idx),
|
|
448
|
+
lambda m: (m.key, m.domain, m.curr_domain_idx), # type: ignore
|
|
300
449
|
self,
|
|
301
450
|
new_attributes,
|
|
302
451
|
)
|
|
@@ -309,7 +458,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
309
458
|
def _get_border_operands(
|
|
310
459
|
self,
|
|
311
460
|
) -> tuple[
|
|
312
|
-
|
|
461
|
+
PRNGKeyArray,
|
|
313
462
|
Float[Array, " nb 1+1 2"] | Float[Array, " (nb//4) 2+1 4"] | None,
|
|
314
463
|
int,
|
|
315
464
|
int | None,
|
|
@@ -355,7 +504,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
355
504
|
# handled above
|
|
356
505
|
)
|
|
357
506
|
new = eqx.tree_at(
|
|
358
|
-
lambda m: (m.key, m.border, m.curr_border_idx),
|
|
507
|
+
lambda m: (m.key, m.border, m.curr_border_idx), # type: ignore
|
|
359
508
|
self,
|
|
360
509
|
new_attributes,
|
|
361
510
|
)
|
|
@@ -372,7 +521,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
372
521
|
|
|
373
522
|
def _get_initial_operands(
|
|
374
523
|
self,
|
|
375
|
-
) -> tuple[
|
|
524
|
+
) -> tuple[PRNGKeyArray, Float[Array, " ni dim"] | None, int, int | None, None]:
|
|
376
525
|
return (
|
|
377
526
|
self.key,
|
|
378
527
|
self.initial,
|
|
@@ -401,7 +550,7 @@ class CubicMeshPDENonStatio(CubicMeshPDEStatio):
|
|
|
401
550
|
# handled above
|
|
402
551
|
)
|
|
403
552
|
new = eqx.tree_at(
|
|
404
|
-
lambda m: (m.key, m.initial, m.curr_initial_idx),
|
|
553
|
+
lambda m: (m.key, m.initial, m.curr_initial_idx), # type: ignore
|
|
405
554
|
self,
|
|
406
555
|
new_attributes,
|
|
407
556
|
)
|